@@ -201,6 +201,9 @@ def __init__(
201201 max_margin_collisions = delta_collision_samples + margin_waveforms
202202 self ._margin = max_margin_collisions
203203
204+ # for some edge cases a template can be zero, leading to problems later
205+ template_is_zero = [np .all (template == 0 ) for template in all_templates ]
206+
204207 self ._all_templates = all_templates
205208 self ._sparsity_mask = sparsity_mask
206209 self ._nbefore = nbefore
@@ -209,6 +212,7 @@ def __init__(
209212 self ._cut_out_after = cut_out_after
210213 self ._handle_collisions = handle_collisions
211214 self ._delta_collision_samples = delta_collision_samples
215+ self ._template_is_zero = template_is_zero
212216
213217 self ._kwargs .update (
214218 all_templates = all_templates ,
@@ -220,6 +224,7 @@ def __init__(
220224 return_in_uV = return_in_uV ,
221225 handle_collisions = handle_collisions ,
222226 delta_collision_samples = delta_collision_samples ,
227+ template_is_zero = template_is_zero ,
223228 )
224229
225230 def get_dtype (self ):
@@ -239,6 +244,7 @@ def compute(self, traces, peaks):
239244 cut_out_after = self ._cut_out_after
240245 handle_collisions = self ._handle_collisions
241246 delta_collision_samples = self ._delta_collision_samples
247+ template_is_zero = self ._template_is_zero
242248
243249 # local_spikes_within_margin = peaks
244250 # i0 = np.searchsorted(local_spikes_within_margin["sample_index"], left_margin)
@@ -265,7 +271,14 @@ def compute(self, traces, peaks):
265271 if spike_index in collisions .keys ():
266272 # we deal with overlapping spikes later
267273 continue
274+
268275 unit_index = spike ["unit_index" ]
276+
277+ if template_is_zero [unit_index ]:
278+ # if template is zero, linregress will fail so we intervene
279+ scalings [spike_index ] = 0
280+ continue
281+
269282 sample_centered = spike ["sample_index" ]
270283 (sparse_indices ,) = np .nonzero (sparsity_mask [unit_index ])
271284 template = all_templates [unit_index ][:, sparse_indices ]
0 commit comments