Skip to content

Commit 1f0339f

Browse files
authored
Set amplitude scalings equal zero for zero template units (#4455)
1 parent f6e45ce commit 1f0339f

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/spikeinterface/postprocessing/amplitude_scalings.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ def __init__(
201201
max_margin_collisions = delta_collision_samples + margin_waveforms
202202
self._margin = max_margin_collisions
203203

204+
# for some edge cases a template can be zero, leading to problems later
205+
template_is_zero = [np.all(template == 0) for template in all_templates]
206+
204207
self._all_templates = all_templates
205208
self._sparsity_mask = sparsity_mask
206209
self._nbefore = nbefore
@@ -209,6 +212,7 @@ def __init__(
209212
self._cut_out_after = cut_out_after
210213
self._handle_collisions = handle_collisions
211214
self._delta_collision_samples = delta_collision_samples
215+
self._template_is_zero = template_is_zero
212216

213217
self._kwargs.update(
214218
all_templates=all_templates,
@@ -220,6 +224,7 @@ def __init__(
220224
return_in_uV=return_in_uV,
221225
handle_collisions=handle_collisions,
222226
delta_collision_samples=delta_collision_samples,
227+
template_is_zero=template_is_zero,
223228
)
224229

225230
def get_dtype(self):
@@ -239,6 +244,7 @@ def compute(self, traces, peaks):
239244
cut_out_after = self._cut_out_after
240245
handle_collisions = self._handle_collisions
241246
delta_collision_samples = self._delta_collision_samples
247+
template_is_zero = self._template_is_zero
242248

243249
# local_spikes_within_margin = peaks
244250
# i0 = np.searchsorted(local_spikes_within_margin["sample_index"], left_margin)
@@ -265,7 +271,14 @@ def compute(self, traces, peaks):
265271
if spike_index in collisions.keys():
266272
# we deal with overlapping spikes later
267273
continue
274+
268275
unit_index = spike["unit_index"]
276+
277+
if template_is_zero[unit_index]:
278+
# if template is zero, linregress will fail so we intervene
279+
scalings[spike_index] = 0
280+
continue
281+
269282
sample_centered = spike["sample_index"]
270283
(sparse_indices,) = np.nonzero(sparsity_mask[unit_index])
271284
template = all_templates[unit_index][:, sparse_indices]

0 commit comments

Comments
 (0)