Skip to content

Commit dc4ec1e

Browse files
ygersamuelgarciapre-commit-ci[bot]
authored
Optimizations: faster waveforms/templates and seeding iterative_isosplit properly (#4402)
Co-authored-by: Samuel Garcia <sam.garcia.die@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1a129ad commit dc4ec1e

2 files changed

Lines changed: 40 additions & 26 deletions

File tree

src/spikeinterface/core/waveform_tools.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,19 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dic
381381
l1 = i1 + s0
382382

383383
if l1 > l0:
384-
start = spikes[l0]["sample_index"] - nbefore
385-
end = spikes[l1 - 1]["sample_index"] + nafter
384+
385+
sub_spikes = in_seg_spikes[i0:i1]
386+
start = sub_spikes[0]["sample_index"] - nbefore
387+
end = sub_spikes[-1]["sample_index"] + nafter
386388

387389
# load trace in memory
388390
traces = recording.get_traces(
389391
start_frame=start, end_frame=end, segment_index=segment_index, return_in_uV=return_in_uV
390392
)
391393

394+
onset = start + nbefore
395+
offset = nbefore + nafter
396+
392397
for unit_ind, unit_id in enumerate(unit_ids):
393398
# find pos
394399
inds = inds_by_unit[unit_id]
@@ -404,8 +409,8 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dic
404409
wfs = worker_dict["waveforms_by_units"][unit_id]
405410

406411
for pos in in_chunk_pos:
407-
sample_index = spikes[inds[pos]]["sample_index"]
408-
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]
412+
sample_index = spikes["sample_index"][inds[pos]] - onset
413+
wf = traces[sample_index : sample_index + offset, :]
409414

410415
if sparsity_mask is None:
411416
wfs[pos, :, :] = wf
@@ -639,23 +644,24 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work
639644
in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
640645
)
641646

642-
# slice in absolut in spikes vector
643-
l0 = i0 + s0
644-
l1 = i1 + s0
645-
646-
if l1 > l0:
647-
start = spikes[l0]["sample_index"] - nbefore
648-
end = spikes[l1 - 1]["sample_index"] + nafter
647+
if i1 > i0:
648+
sub_spikes = in_seg_spikes[i0:i1]
649+
start = sub_spikes[0]["sample_index"] - nbefore
650+
end = sub_spikes[-1]["sample_index"] + nafter
649651

650652
# load trace in memory
651653
traces = recording.get_traces(
652654
start_frame=start, end_frame=end, segment_index=segment_index, return_in_uV=return_in_uV
653655
)
654656

655-
for spike_index in range(l0, l1):
656-
sample_index = spikes[spike_index]["sample_index"]
657-
unit_index = spikes[spike_index]["unit_index"]
658-
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]
657+
onset = start + nbefore
658+
offset = nbefore + nafter
659+
sample_indices = sub_spikes["sample_index"] - onset
660+
unit_indices = sub_spikes["unit_index"]
661+
spike_indices = s0 + np.arange(i0, i1)
662+
663+
for sample_index, unit_index, spike_index in zip(sample_indices, unit_indices, spike_indices):
664+
wf = traces[sample_index : sample_index + offset, :]
659665

660666
if sparsity_mask is None:
661667
all_waveforms[spike_index, :, :] = wf
@@ -1055,23 +1061,25 @@ def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dic
10551061
in_seg_spikes["sample_index"], [max(start_frame, nbefore), min(end_frame, seg_size - nafter)]
10561062
)
10571063

1058-
# slice in absolut in spikes vector
1059-
l0 = i0 + s0
1060-
l1 = i1 + s0
1064+
if i1 > i0:
1065+
sub_spikes = in_seg_spikes[i0:i1]
10611066

1062-
if l1 > l0:
1063-
start = spikes[l0]["sample_index"] - nbefore
1064-
end = spikes[l1 - 1]["sample_index"] + nafter
1067+
start = sub_spikes[0]["sample_index"] - nbefore
1068+
end = sub_spikes[-1]["sample_index"] + nafter
10651069

10661070
# load trace in memory
10671071
traces = recording.get_traces(
10681072
start_frame=start, end_frame=end, segment_index=segment_index, return_in_uV=return_in_uV
10691073
)
10701074

1071-
for spike_index in range(l0, l1):
1072-
sample_index = spikes[spike_index]["sample_index"]
1073-
unit_index = spikes[spike_index]["unit_index"]
1074-
wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :]
1075+
onset = start + nbefore
1076+
offset = nbefore + nafter
1077+
sample_indices = sub_spikes["sample_index"] - onset
1078+
unit_indices = sub_spikes["unit_index"]
1079+
1080+
for sample_index, unit_index in zip(sample_indices, unit_indices):
1081+
1082+
wf = traces[sample_index : sample_index + offset, :]
10751083

10761084
if sparsity_mask is None:
10771085
waveform_accumulator_per_worker[worker_index, unit_index, :, :] += wf

src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
116116
debug_folder = params["debug_folder"]
117117

118118
params_peak_svd = params["peaks_svd"].copy()
119-
params_peak_svd["seed"] = params["seed"]
119+
seed = params["seed"]
120+
params_peak_svd["seed"] = seed
120121
motion = params_peak_svd["motion"]
121122
motion_aware = motion is not None
122123

@@ -136,9 +137,14 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
136137
else:
137138
peaks_svd, sparse_mask, svd_model = outs
138139

140+
# Clustering: channel index > split > merge
139141
# Clustering: channel index > split > merge
140142
split_params = params["split"].copy()
141143

144+
if seed is not None:
145+
params_peak_svd.update(seed=seed)
146+
split_params["method_kwargs"].update(seed=seed)
147+
142148
split_radius_um = split_params.pop("split_radius_um")
143149
neighbours_mask = get_channel_distances(recording) <= split_radius_um
144150
split_params["method_kwargs"]["neighbours_mask"] = neighbours_mask

0 commit comments

Comments
 (0)