Skip to content

Commit ab4dedf

Browse files
committed
refac: remove valid_mask since all spikes are valid
1 parent 3448452 commit ab4dedf

1 file changed

Lines changed: 13 additions & 19 deletions

File tree

src/spikeinterface/postprocessing/principal_component.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -629,41 +629,36 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
629629
if i0 == i1:
630630
return
631631

632+
# Since we get_traces accounting for nbefore and nafter, all spikes in the chunk are valid and we can extract
633+
# all waveforms in one go without worrying about borders.
632634
start = int(spike_times[i0] - nbefore)
633635
end = int(spike_times[i1 - 1] + nafter)
634636
traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index)
635637

636638
nsamples = nbefore + nafter
637639

638640
# Extract all waveforms in the chunk at once
639-
# valid_mask tracks which spikes have valid (in-bounds) waveforms
640-
chunk_spike_times = spike_times[i0:i1]
641-
offsets = chunk_spike_times - start - nbefore
642-
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
641+
spike_times_in_chunk = spike_times[i0:i1]
642+
# Offset spike times to be relative to the start of the traces buffer
643+
spike_times_offset = spike_times_in_chunk - start - nbefore
644+
spike_indices = np.arange(i0, i1)
643645

644-
if not np.any(valid_mask):
645-
return
646-
647-
valid_offsets = offsets[valid_mask]
648-
valid_indices = np.arange(i0, i1)[valid_mask]
649-
n_valid = len(valid_offsets)
650-
651-
# Build waveform array: (n_valid, nsamples, n_channels)
646+
# Build waveform array: (n_spikes, nsamples, n_channels)
652647
# Use fancy indexing to extract all snippets at once
653-
sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples)
654-
all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels)
648+
sample_indices = spike_times_offset[:, None] + np.arange(nsamples)[None, :] # (n_spikes, nsamples)
649+
all_wfs = traces[sample_indices] # (n_spikes, nsamples, n_channels)
655650

656651
# Vectorized PCA: batch by channel across all spikes in the chunk.
657652
# For each unique channel, find all spikes that use it (via their unit's
658653
# sparsity), extract waveforms, and call transform once.
659-
valid_labels = spike_labels[valid_indices]
654+
labels_in_chunk = spike_labels[spike_indices]
660655

661656
# Build a set of all channels used by spikes in this chunk
662-
unique_unit_indices = np.unique(valid_labels)
657+
unique_unit_indices = np.unique(labels_in_chunk)
663658
chan_info: dict[int, list[tuple[np.ndarray, int]]] = {}
664659
for unit_index in unique_unit_indices:
665660
chan_inds = unit_channels[unit_index]
666-
unit_mask = valid_labels == unit_index
661+
unit_mask = labels_in_chunk == unit_index
667662
unit_local_idxs = np.nonzero(unit_mask)[0]
668663
for c, chan_ind in enumerate(chan_inds):
669664
if chan_ind not in chan_info:
@@ -673,14 +668,13 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
673668
for chan_ind, unit_groups in chan_info.items():
674669
# Concatenate all spike indices for this channel across units
675670
all_local_idxs = np.concatenate([g[0] for g in unit_groups])
676-
global_idxs = valid_indices[all_local_idxs]
671+
global_idxs = spike_indices[all_local_idxs]
677672

678673
# Batch waveforms for this channel: (n_spikes, nsamples)
679674
wfs_batch = all_wfs[all_local_idxs, :, chan_ind]
680675

681676
if wfs_batch.size == 0:
682677
continue
683-
684678
try:
685679
pcs_batch = pca_model[chan_ind].transform(wfs_batch)
686680
# Write results back — each unit group has a fixed channel position

0 commit comments

Comments
 (0)