@@ -629,41 +629,36 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
629629 if i0 == i1 :
630630 return
631631
632+ # The traces are requested with nbefore/nafter margins around the first and last spike of the chunk, so every
633+ # spike in the chunk is valid and all waveforms can be extracted in one go without worrying about borders.
632634 start = int (spike_times [i0 ] - nbefore )
633635 end = int (spike_times [i1 - 1 ] + nafter )
634636 traces = recording .get_traces (start_frame = start , end_frame = end , segment_index = segment_index )
635637
636638 nsamples = nbefore + nafter
637639
638640 # Extract all waveforms in the chunk at once
639- # valid_mask tracks which spikes have valid (in-bounds) waveforms
640- chunk_spike_times = spike_times [ i0 : i1 ]
641- offsets = chunk_spike_times - start - nbefore
642- valid_mask = ( offsets >= 0 ) & ( offsets + nsamples <= traces . shape [ 0 ] )
641+ spike_times_in_chunk = spike_times [ i0 : i1 ]
642+ # Index of each waveform's first sample within the traces buffer (spike time minus nbefore, relative to start)
643+ spike_times_offset = spike_times_in_chunk - start - nbefore
644+ spike_indices = np . arange ( i0 , i1 )
643645
644- if not np .any (valid_mask ):
645- return
646-
647- valid_offsets = offsets [valid_mask ]
648- valid_indices = np .arange (i0 , i1 )[valid_mask ]
649- n_valid = len (valid_offsets )
650-
651- # Build waveform array: (n_valid, nsamples, n_channels)
646+ # Build waveform array: (n_spikes, nsamples, n_channels)
652647 # Use fancy indexing to extract all snippets at once
653- sample_indices = valid_offsets [:, None ] + np .arange (nsamples )[None , :] # (n_valid , nsamples)
654- all_wfs = traces [sample_indices ] # (n_valid , nsamples, n_channels)
648+ sample_indices = spike_times_offset [:, None ] + np .arange (nsamples )[None , :] # (n_spikes , nsamples)
649+ all_wfs = traces [sample_indices ] # (n_spikes , nsamples, n_channels)
655650
656651 # Vectorized PCA: batch by channel across all spikes in the chunk.
657652 # For each unique channel, find all spikes that use it (via their unit's
658653 # sparsity), extract waveforms, and call transform once.
659- valid_labels = spike_labels [valid_indices ]
654+ labels_in_chunk = spike_labels [spike_indices ]
660655
661656 # Build a set of all channels used by spikes in this chunk
662- unique_unit_indices = np .unique (valid_labels )
657+ unique_unit_indices = np .unique (labels_in_chunk )
663658 chan_info : dict [int , list [tuple [np .ndarray , int ]]] = {}
664659 for unit_index in unique_unit_indices :
665660 chan_inds = unit_channels [unit_index ]
666- unit_mask = valid_labels == unit_index
661+ unit_mask = labels_in_chunk == unit_index
667662 unit_local_idxs = np .nonzero (unit_mask )[0 ]
668663 for c , chan_ind in enumerate (chan_inds ):
669664 if chan_ind not in chan_info :
@@ -673,14 +668,13 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
673668 for chan_ind , unit_groups in chan_info .items ():
674669 # Concatenate all spike indices for this channel across units
675670 all_local_idxs = np .concatenate ([g [0 ] for g in unit_groups ])
676- global_idxs = valid_indices [all_local_idxs ]
671+ global_idxs = spike_indices [all_local_idxs ]
677672
678673 # Batch waveforms for this channel: (n_spikes, nsamples)
679674 wfs_batch = all_wfs [all_local_idxs , :, chan_ind ]
680675
681676 if wfs_batch .size == 0 :
682677 continue
683-
684678 try :
685679 pcs_batch = pca_model [chan_ind ].transform (wfs_batch )
686680 # Write results back — each unit group has a fixed channel position
0 commit comments