Skip to content

Commit c08e933

Browse files
tayheautayheaupre-commit-ci[bot]samuelgarciaalejoe91
authored
Random spike selection new methods (#4276)
Co-authored-by: tayheau <thopsore@WD25-1022.corp.pasteur.fr> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Samuel Garcia <sam.garcia.die@gmail.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent ea2ef8c commit c08e933

4 files changed

Lines changed: 68 additions & 23 deletions

File tree

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,18 @@ class ComputeRandomSpikes(AnalyzerExtension):
3535
3636
Parameters
3737
----------
38-
method : "uniform" | "all", default: "uniform"
39-
The method to select the spikes
38+
method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform"
39+
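Method to select spikes: "uniform" randomly picks up to max_spikes_per_unit spikes per unit, "percentage" randomly picks a fraction of each unit's spikes, "maximum_rate" caps the number of randomly picked spikes by a rate over the total duration, and "all" keeps every spike.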
4040
max_spikes_per_unit : int, default: 500
4141
The maximum number of spikes per unit, ignored if method="all"
4242
margin_size : int, default: None
4343
A margin on each border of segments to avoid border spikes, ignored if method="all"
4444
seed : int or None, default: None
4545
A seed for the random generator, ignored if method="all"
46+
percentage: float | None, default: None
47+
Only used when method="percentage": the proportion of each unit's spikes to select, in the interval (0, 1].
48+
maximum_rate: float | None, default: None
49+
Only used when method="maximum_rate": the maximum rate of selected spikes per unit, relative to the total duration of the sorting.
4650
4751
Returns
4852
-------
@@ -64,7 +68,9 @@ def _run(self, verbose=False):
6468
**self.params,
6569
)
6670

67-
def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None):
71+
def _set_params(
72+
self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None
73+
):
6874
params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed)
6975
return params
7076

src/spikeinterface/core/sorting_tools.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import warnings
22
import importlib.util
33

4+
from typing import Literal
5+
46
import numpy as np
57

68
from spikeinterface.core.base import BaseExtractor, unit_period_dtype
@@ -146,14 +148,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
146148
return vector_to_list_of_spiketrain_numba
147149

148150

149-
# TODO later : implement other method like "maximum_rate", "by_percent", ...
151+
# TODO later : implement stratified sampling (isi / amplitude / pca distance ?)
150152
def random_spikes_selection(
151153
sorting: BaseSorting,
152-
num_samples: int | None = None,
153-
method: str = "uniform",
154+
num_samples: list[int] | None = None,
155+
method: Literal["uniform", "all", "percentage", "maximum_rate"] = "uniform",
154156
max_spikes_per_unit: int = 500,
155157
margin_size: int | None = None,
156158
seed: int | None = None,
159+
percentage: float | None = None,
160+
maximum_rate: float | None = None,
157161
):
158162
"""
159163
This replaces `select_random_spikes_uniformly()`.
@@ -165,41 +169,57 @@ def random_spikes_selection(
165169
----------
166170
sorting: BaseSorting
167171
The sorting object
168-
num_samples: list of int
172+
num_samples: list[int] | None, default: None
169173
The number of samples per segment.
170174
Can be retrieved from recording with
171175
num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
172-
method: "uniform" | "all", default: "uniform"
173-
The method to use. Only "uniform" is implemented for now
176+
method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform"
177+
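Method to select spikes: "uniform" randomly picks up to max_spikes_per_unit spikes per unit, "percentage" randomly picks a fraction of each unit's spikes, "maximum_rate" caps the number of randomly picked spikes at maximum_rate times the total duration, and "all" keeps every spike. "percentage" and "maximum_rate" are also capped by max_spikes_per_unit.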
174178
max_spikes_per_unit: int, default: 500
175-
The number of spikes per units
179+
The maximum number of spikes per unit
176180
margin_size: None | int, default: None
177181
A margin on each border of segments to avoid border spikes
178182
seed: None | int, default: None
179183
A seed for random generator
184+
percentage: float | None, default: None
185+
Only used when method="percentage": the proportion of each unit's spikes to select, in the interval (0, 1].
186+
maximum_rate: float | None, default: None
187+
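Only used when method="maximum_rate": the maximum rate of selected spikes per unit; it is multiplied by the total duration of all segments to get the cap on the number of spikes.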
180188
181189
Returns
182190
-------
183191
random_spikes_indices: np.array
184192
Selected spike indices corresponding to the sorting spike vector.
185193
"""
194+
rng_methods = ("uniform", "percentage", "maximum_rate")
195+
196+
if method == "all":
197+
spikes = sorting.to_spike_vector()
198+
random_spikes_indices = np.arange(spikes.size)
199+
200+
elif method in rng_methods:
201+
from spikeinterface.widgets.utils import get_segment_durations
186202

187-
if method == "uniform":
188203
rng = np.random.default_rng(seed=seed)
189204

205+
# the spike vector is not concatenated, so `spikes` is a list with one array per segment:
206+
# spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]]
190207
spikes = sorting.to_spike_vector(concatenated=False)
191208
cum_sizes = np.cumsum([0] + [s.size for s in spikes])
192209

193-
# this fast when numba
210+
# this is fast when numba is installed
194211
spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False)
195212

196213
random_spikes_indices = []
197214
for unit_index, unit_id in enumerate(sorting.unit_ids):
198215
all_unit_indices = []
199216
for segment_index in range(sorting.get_num_segments()):
200-
# this is local index
217+
# this is local segment index
201218
inds_in_seg = spike_indices[segment_index][unit_id]
202219
if margin_size is not None:
220+
if num_samples is None:
221+
raise ValueError("num_samples must be provided when margin_size is used")
222+
203223
local_spikes = spikes[segment_index][inds_in_seg]
204224
mask = (local_spikes["sample_index"] >= margin_size) & (
205225
local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)
@@ -209,19 +229,33 @@ def random_spikes_selection(
209229
inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
210230
all_unit_indices.append(inds_in_seg_abs)
211231
all_unit_indices = np.concatenate(all_unit_indices)
212-
selected_unit_indices = rng.choice(
213-
all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False
214-
)
232+
233+
if method == "uniform":
234+
rng_size = min(max_spikes_per_unit, all_unit_indices.size)
235+
selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False)
236+
237+
elif method == "percentage":
238+
if percentage is None or not (0 < percentage <= 1):
239+
raise ValueError(f"percentage must be in the interval (0, 1]")
240+
241+
rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage))
242+
selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False)
243+
244+
elif method == "maximum_rate":
245+
if maximum_rate is None:
246+
raise ValueError(f"maximum_rate must be defined")
247+
248+
t_duration = np.sum(get_segment_durations(sorting))
249+
rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size)
250+
selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False)
251+
215252
random_spikes_indices.append(selected_unit_indices)
216253

217254
random_spikes_indices = np.concatenate(random_spikes_indices)
218255
random_spikes_indices = np.sort(random_spikes_indices)
219256

220-
elif method == "all":
221-
spikes = sorting.to_spike_vector()
222-
random_spikes_indices = np.arange(spikes.size)
223257
else:
224-
raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'")
258+
raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}")
225259

226260
return random_spikes_indices
227261

src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ def fit(
9494
model_folder_path: str,
9595
detect_peaks_params: dict,
9696
peak_selection_params: dict,
97-
job_kwargs: dict = None,
97+
job_kwargs: dict | None = None,
9898
ms_before: float = 1.0,
9999
ms_after: float = 1.0,
100100
whiten: bool = True,
101-
radius_um: float = None,
101+
radius_um: float | None = None,
102102
) -> "IncrementalPCA":
103103
"""
104104
Train a pca model using the data in the recording object and the parameters provided.

src/spikeinterface/widgets/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,20 +399,25 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor
399399
return segment_indices
400400

401401

402-
def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]:
402+
def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = None) -> list[float]:
403403
"""
404404
Calculate the duration of each segment in a sorting object.
405405
406406
Parameters
407407
----------
408408
sorting : BaseSorting
409409
The sorting object containing spike data
410+
segment_indices : list[int] | None
411+
List of the segment indices to process. If None, all segments are used.
410412
411413
Returns
412414
-------
413415
list[float]
414416
List of segment durations in seconds
415417
"""
418+
if segment_indices is None:
419+
segment_indices = range(sorting.get_num_segments())
420+
416421
spikes = sorting.to_spike_vector()
417422

418423
segment_boundaries = [

0 commit comments

Comments
 (0)