@@ -381,14 +381,19 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dic
381381 l1 = i1 + s0
382382
383383 if l1 > l0 :
384- start = spikes [l0 ]["sample_index" ] - nbefore
385- end = spikes [l1 - 1 ]["sample_index" ] + nafter
384+
385+ sub_spikes = in_seg_spikes [i0 :i1 ]
386+ start = sub_spikes [0 ]["sample_index" ] - nbefore
387+ end = sub_spikes [- 1 ]["sample_index" ] + nafter
386388
387389 # load trace in memory
388390 traces = recording .get_traces (
389391 start_frame = start , end_frame = end , segment_index = segment_index , return_in_uV = return_in_uV
390392 )
391393
394+ onset = start + nbefore
395+ offset = nbefore + nafter
396+
392397 for unit_ind , unit_id in enumerate (unit_ids ):
393398 # find pos
394399 inds = inds_by_unit [unit_id ]
@@ -404,8 +409,8 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dic
404409 wfs = worker_dict ["waveforms_by_units" ][unit_id ]
405410
406411 for pos in in_chunk_pos :
407- sample_index = spikes [inds [pos ]][ "sample_index" ]
408- wf = traces [sample_index - start - nbefore : sample_index - start + nafter , :]
412+ sample_index = spikes ["sample_index" ][ inds [pos ]] - onset
413+ wf = traces [sample_index : sample_index + offset , :]
409414
410415 if sparsity_mask is None :
411416 wfs [pos , :, :] = wf
@@ -639,23 +644,24 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
639644 in_seg_spikes ["sample_index" ], [max (start_frame , nbefore ), min (end_frame , seg_size - nafter )]
640645 )
641646
642- # slice in absolut in spikes vector
643- l0 = i0 + s0
644- l1 = i1 + s0
645-
646- if l1 > l0 :
647- start = spikes [l0 ]["sample_index" ] - nbefore
648- end = spikes [l1 - 1 ]["sample_index" ] + nafter
647+ if i1 > i0 :
648+ sub_spikes = in_seg_spikes [i0 :i1 ]
649+ start = sub_spikes [0 ]["sample_index" ] - nbefore
650+ end = sub_spikes [- 1 ]["sample_index" ] + nafter
649651
650652 # load trace in memory
651653 traces = recording .get_traces (
652654 start_frame = start , end_frame = end , segment_index = segment_index , return_in_uV = return_in_uV
653655 )
654656
655- for spike_index in range (l0 , l1 ):
656- sample_index = spikes [spike_index ]["sample_index" ]
657- unit_index = spikes [spike_index ]["unit_index" ]
658- wf = traces [sample_index - start - nbefore : sample_index - start + nafter , :]
657+ onset = start + nbefore
658+ offset = nbefore + nafter
659+ sample_indices = sub_spikes ["sample_index" ] - onset
660+ unit_indices = sub_spikes ["unit_index" ]
661+ spike_indices = s0 + np .arange (i0 , i1 )
662+
663+ for sample_index , unit_index , spike_index in zip (sample_indices , unit_indices , spike_indices ):
664+ wf = traces [sample_index : sample_index + offset , :]
659665
660666 if sparsity_mask is None :
661667 all_waveforms [spike_index , :, :] = wf
@@ -1055,23 +1061,25 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10551061 in_seg_spikes ["sample_index" ], [max (start_frame , nbefore ), min (end_frame , seg_size - nafter )]
10561062 )
10571063
1058- # slice in absolut in spikes vector
1059- l0 = i0 + s0
1060- l1 = i1 + s0
1064+ if i1 > i0 :
1065+ sub_spikes = in_seg_spikes [i0 :i1 ]
10611066
1062- if l1 > l0 :
1063- start = spikes [l0 ]["sample_index" ] - nbefore
1064- end = spikes [l1 - 1 ]["sample_index" ] + nafter
1067+ start = sub_spikes [0 ]["sample_index" ] - nbefore
1068+ end = sub_spikes [- 1 ]["sample_index" ] + nafter
10651069
10661070 # load trace in memory
10671071 traces = recording .get_traces (
10681072 start_frame = start , end_frame = end , segment_index = segment_index , return_in_uV = return_in_uV
10691073 )
10701074
1071- for spike_index in range (l0 , l1 ):
1072- sample_index = spikes [spike_index ]["sample_index" ]
1073- unit_index = spikes [spike_index ]["unit_index" ]
1074- wf = traces [sample_index - start - nbefore : sample_index - start + nafter , :]
1075+ onset = start + nbefore
1076+ offset = nbefore + nafter
1077+ sample_indices = sub_spikes ["sample_index" ] - onset
1078+ unit_indices = sub_spikes ["unit_index" ]
1079+
1080+ for sample_index , unit_index in zip (sample_indices , unit_indices ):
1081+
1082+ wf = traces [sample_index : sample_index + offset , :]
10751083
10761084 if sparsity_mask is None :
10771085 waveform_accumulator_per_worker [worker_index , unit_index , :, :] += wf
0 commit comments