@@ -151,6 +151,8 @@ def __init__(
151151
152152 self .sparse_templates_array_static = None
153153
154+ self .dtype = self .sparse_templates_array_moved .dtype
155+
154156 # interpolation bins edges
155157 self .interpolation_time_bins_s = []
156158 self .interpolation_time_bin_edges_s = []
@@ -171,6 +173,7 @@ def __init__(
171173 self .interpolation_time_bins_s = None
172174 self .interpolation_time_bin_edges_s = None
173175 self .sparse_templates_array_static = templates .templates_array
176+ self .dtype = self .sparse_templates_array_static .dtype
174177
175178 extremum_chan = get_template_extremum_channel (templates , peak_sign = peak_sign , outputs = "index" )
176179 # as numpy vector
@@ -271,6 +274,7 @@ def get_trace_margin(self):
271274 def compute_matching (self , traces , start_frame , end_frame , segment_index ):
272275
273276 # TODO check if this is usefull
277+ traces = traces .astype (self .dtype )
274278 residuals = traces .copy ()
275279
276280 if self .motion_aware :
@@ -375,10 +379,20 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
375379 if spikes .size > 0 :
376380 spikes_in_time_bin .append (spikes )
377381
382+ # # DEBUG
383+ # if spikes.size != np.unique(spikes).size:
384+ # print('In loop double spikes', spikes.size, np.unique(spikes).size)
385+ # spikes2 = spikes.copy()
386+ # order = np.argsort(spikes2['sample_index'])
387+ # spikes2 = spikes2[order]
388+ # print(spikes2)
389+ # print()
390+
378391 level += 1
379392
380393 # TODO concatenate all spikes for this instead of prev loop
381- spikes_prev_loop = spikes
394+ # spikes_prev_loop = spikes
395+ spikes_prev_loop = np .concatenate ((spikes_prev_loop , spikes ))
382396
383397 if (spikes .size == 0 ) or (level == self .max_peeler_loop ):
384398 if self .use_fine_detector and not use_fine_detector_level :
@@ -401,6 +415,22 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
401415 else :
402416 all_spikes = np .zeros (0 , dtype = _base_matching_dtype )
403417
418+ # DEBUG
419+ # if all_spikes.size != np.unique(all_spikes).size:
420+ # print('After loop double spikes', all_spikes.size, np.unique(all_spikes).size)
421+ # all_spikes2 = all_spikes.copy()
422+ # order = np.argsort(all_spikes2['sample_index'])
423+ # all_spikes2 = all_spikes2[order]
424+ # inds = np.flatnonzero(np.diff(all_spikes2[order]['sample_index']) == 0)
425+ # keep = np.zeros(all_spikes2.size, dtype='bool')
426+ # keep[inds] = 1
427+ # keep[inds+1] = 1
428+ # print(all_spikes2[keep])
429+ # # print(all_spikes2)
430+
431+ # import time
432+ # time.sleep(0.5)
433+
404434 return all_spikes
405435
406436 def _find_spikes_one_level (self , traces , spikes_prev_loop , use_fine_detector , level , channel_motions ):
@@ -578,6 +608,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le
578608 if low_lim <= amp <= up_lim :
579609 spikes ["amplitude" ][i ] = amp
580610 wanted_channel_mask = np .ones (traces .shape [1 ], dtype = bool ) # TODO move this before the loop
611+ # TODO check dtype are the same
581612 construct_prediction_sparse (
582613 spikes [i : i + 1 ],
583614 traces ,
0 commit comments