Skip to content

Commit 90a3fc6

Browse files
authored
Merge branch 'main' into fix_loading_sync_stream_in_spikeinterface
2 parents 5d17669 + 604c049 commit 90a3fc6

65 files changed

Lines changed: 804 additions & 433 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

examples/tutorials/core/plot_4_sorting_analyzer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@
4343
##############################################################################
4444
# Let's now instantiate the recording and sorting objects:
4545

46-
recording = se.MEArecRecordingExtractor(local_path)
46+
recording, sorting = se.read_mearec(local_path)
4747
print(recording)
48-
sorting = se.MEArecSortingExtractor(local_path)
4948
print(sorting)
5049

5150
###############################################################################

examples/tutorials/extractors/plot_1_read_various_formats.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
# :py:class:`~spikeinterface.extractors.Spike2RecordingExtractor` object:
6262
#
6363

64-
recording = se.Spike2RecordingExtractor(spike2_file_path, stream_id="0")
64+
recording = se.read_spike2(spike2_file_path, stream_id="0")
6565
print(recording)
6666

6767
##############################################################################
@@ -75,11 +75,6 @@
7575
print(sorting)
7676
print(type(sorting))
7777

78-
##############################################################################
79-
# The :py:func:`~spikeinterface.extractors.read_mearec` function is equivalent to:
80-
81-
recording = se.MEArecRecordingExtractor(mearec_folder_path)
82-
sorting = se.MEArecSortingExtractor(mearec_folder_path)
8378

8479
##############################################################################
8580
# SI objects (:py:class:`~spikeinterface.core.BaseRecording` and :py:class:`~spikeinterface.core.BaseSorting`)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ qualitymetrics = [
120120
]
121121

122122
test_core = [
123-
"pytest",
123+
"pytest<8.4.0",
124124
"pytest-dependency",
125125
"psutil",
126126

@@ -146,7 +146,7 @@ test_preprocessing = [
146146

147147

148148
test = [
149-
"pytest",
149+
"pytest<8.4.0",
150150
"pytest-dependency",
151151
"pytest-cov",
152152
"psutil",

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
134134
else:
135135
analyzer = data
136136

137-
rec, gt_sorting = analyzer.recording, analyzer.sorting
138-
139137
analyzers_path[key] = str(analyzer.folder.resolve())
140138

141139
# recordings are pickled
@@ -180,7 +178,11 @@ def scan_folder(self):
180178
self.analyzers[key] = analyzer
181179
# the sorting is in memory here we take the saved one because comparisons need to pickle it later
182180
sorting = load(analyzer.folder / "sorting")
183-
self.datasets[key] = analyzer.recording, sorting
181+
if analyzer.has_recording():
182+
recording = analyzer.recording
183+
else:
184+
recording = None
185+
self.datasets[key] = recording, sorting
184186

185187
with open(self.folder / "cases.pickle", "rb") as f:
186188
self.cases = pickle.load(f)

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _run(self, verbose=False, **job_kwargs):
194194
self.nbefore,
195195
self.nafter,
196196
mode=mode,
197-
return_scaled=self.sorting_analyzer.return_scaled,
197+
return_in_uV=self.sorting_analyzer.return_in_uV,
198198
file_path=file_path,
199199
dtype=self.params["dtype"],
200200
sparsity_mask=sparsity_mask,
@@ -216,7 +216,7 @@ def _set_params(
216216
if dtype is None:
217217
dtype = recording.get_dtype()
218218

219-
if np.issubdtype(dtype, np.integer) and self.sorting_analyzer.return_scaled:
219+
if np.issubdtype(dtype, np.integer) and self.sorting_analyzer.return_in_uV:
220220
dtype = "float32"
221221

222222
dtype = np.dtype(dtype)
@@ -427,7 +427,7 @@ def _run(self, verbose=False, **job_kwargs):
427427
# retrieve spike vector and the sampling
428428
some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
429429

430-
return_scaled = self.sorting_analyzer.return_scaled
430+
return_in_uV = self.sorting_analyzer.return_in_uV
431431

432432
return_std = "std" in self.params["operators"]
433433
output = estimate_templates_with_accumulator(
@@ -436,7 +436,7 @@ def _run(self, verbose=False, **job_kwargs):
436436
unit_ids,
437437
self.nbefore,
438438
self.nafter,
439-
return_scaled=return_scaled,
439+
return_in_uV=return_in_uV,
440440
return_std=return_std,
441441
verbose=verbose,
442442
**job_kwargs,
@@ -648,7 +648,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
648648
channel_ids=self.sorting_analyzer.channel_ids,
649649
unit_ids=unit_ids,
650650
probe=self.sorting_analyzer.get_probe(),
651-
is_scaled=self.sorting_analyzer.return_scaled,
651+
is_scaled=self.sorting_analyzer.return_in_uV,
652652
)
653653
else:
654654
raise ValueError("`outputs` must be 'numpy' or 'Templates'")
@@ -732,7 +732,7 @@ def _merge_extension_data(
732732
def _run(self, verbose=False, **job_kwargs):
733733
self.data["noise_levels"] = get_noise_levels(
734734
self.sorting_analyzer.recording,
735-
return_scaled=self.sorting_analyzer.return_scaled,
735+
return_in_uV=self.sorting_analyzer.return_in_uV,
736736
**self.params,
737737
**job_kwargs,
738738
)

src/spikeinterface/core/baserecording.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ def get_traces(
295295
end_frame: int | None = None,
296296
channel_ids: list | np.array | tuple | None = None,
297297
order: "C" | "F" | None = None,
298-
return_scaled: bool = False,
298+
return_scaled: bool | None = None,
299+
return_in_uV: bool = False,
299300
cast_unsigned: bool = False,
300301
) -> np.ndarray:
301302
"""Returns traces from recording.
@@ -312,7 +313,11 @@ def get_traces(
312313
The channel ids. If None, all channels are used, default: None
313314
order : "C" | "F" | None, default: None
314315
The order of the traces ("C" | "F"). If None, traces are returned as they are
315-
return_scaled : bool, default: False
316+
return_scaled : bool | None, default: None
317+
DEPRECATED. Use return_in_uV instead.
318+
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
319+
traces are scaled to uV
320+
return_in_uV : bool, default: False
316321
If True and the recording has scaling (gain_to_uV and offset_to_uV properties),
317322
traces are scaled to uV
318323
cast_unsigned : bool, default: False
@@ -327,7 +332,7 @@ def get_traces(
327332
Raises
328333
------
329334
ValueError
330-
If return_scaled is True, but recording does not have scaled traces
335+
If return_in_uV is True, but recording does not have scaled traces
331336
"""
332337
segment_index = self._check_segment_index(segment_index)
333338
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
@@ -351,15 +356,24 @@ def get_traces(
351356
traces = traces.astype(f"int{2 * (dtype.itemsize) * 8}") - 2 ** (nbits - 1)
352357
traces = traces.astype(f"int{dtype.itemsize * 8}")
353358

354-
if return_scaled:
359+
# Handle deprecated return_scaled parameter
360+
if return_scaled is not None:
361+
warnings.warn(
362+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
363+
category=DeprecationWarning,
364+
stacklevel=2,
365+
)
366+
return_in_uV = return_scaled
367+
368+
if return_in_uV:
355369
if not self.has_scaleable_traces():
356370
if self._dtype.kind == "f":
357371
# here we do not truely have scale but we assume this is scaled
358372
# this helps a lot for simulated data
359373
pass
360374
else:
361375
raise ValueError(
362-
"This recording does not support return_scaled=True (need gain_to_uV and offset_"
376+
"This recording does not support return_in_uV=True (need gain_to_uV and offset_"
363377
"to_uV properties)"
364378
)
365379
else:

src/spikeinterface/core/basesnippets.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,51 @@ def get_snippets(
9797
indices=None,
9898
segment_index: Union[int, None] = None,
9999
channel_ids: Union[list, None] = None,
100-
return_scaled=False,
100+
return_scaled: bool | None = None,
101+
return_in_uV: bool = False,
101102
):
103+
"""
104+
Return the snippets, optionally for a subset of samples and/or channels
105+
106+
Parameters
107+
----------
108+
indices : list[int], default: None
109+
Indices of the snippets to return. If None, all snippets are returned.
110+
segment_index : Union[int, None], default: None
111+
The segment index to get snippets from. If snippets is multi-segment, it is required.
112+
channel_ids : Union[list, None], default: None
113+
The channel ids. If None, all channels are used.
114+
return_scaled : bool | None, default: None
115+
DEPRECATED. Use return_in_uV instead.
116+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
117+
snippets are scaled to uV
118+
return_in_uV : bool, default: False
119+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
120+
snippets are scaled to uV
121+
122+
Returns
123+
-------
124+
np.array
125+
The snippets (num_snippets, num_samples, num_channels)
126+
"""
102127
segment_index = self._check_segment_index(segment_index)
103128
spts = self._snippets_segments[segment_index]
104129
channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True)
105130
wfs = spts.get_snippets(indices, channel_indices=channel_indices)
106131

107-
if return_scaled:
132+
# Handle deprecated return_scaled parameter
133+
if return_scaled is not None:
134+
warn(
135+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
136+
category=DeprecationWarning,
137+
stacklevel=2,
138+
)
139+
return_in_uV = return_scaled
140+
141+
if return_in_uV:
108142
if not self.has_scaleable_traces():
109143
raise ValueError(
110-
"These snippets do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
144+
"These snippets do not support return_in_uV=True (need gain_to_uV and offset_" "to_uV properties)"
111145
)
112146
else:
113147
gains = self.get_property("gain_to_uV")
@@ -123,13 +157,49 @@ def get_snippets_from_frames(
123157
start_frame: Union[int, None] = None,
124158
end_frame: Union[int, None] = None,
125159
channel_ids: Union[list, None] = None,
126-
return_scaled=False,
160+
return_scaled: bool | None = None,
161+
return_in_uV: bool = False,
127162
):
163+
"""
164+
Return the snippets from frames, optionally for a subset of samples and/or channels
165+
166+
Parameters
167+
----------
168+
segment_index : Union[int, None], default: None
169+
The segment index to get snippets from. If snippets is multi-segment, it is required.
170+
start_frame : Union[int, None], default: None
171+
The start frame. If None, 0 is used.
172+
end_frame : Union[int, None], default: None
173+
The end frame. If None, the number of samples in the segment is used.
174+
channel_ids : Union[list, None], default: None
175+
The channel ids. If None, all channels are used.
176+
return_scaled : bool | None, default: None
177+
DEPRECATED. Use return_in_uV instead.
178+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
179+
snippets are scaled to uV
180+
return_in_uV : bool, default: False
181+
If True and the snippets has scaling (gain_to_uV and offset_to_uV properties),
182+
snippets are scaled to uV
183+
184+
Returns
185+
-------
186+
np.array
187+
The snippets (num_snippets, num_samples, num_channels)
188+
"""
128189
segment_index = self._check_segment_index(segment_index)
129190
spts = self._snippets_segments[segment_index]
130191
indices = spts.frames_to_indices(start_frame, end_frame)
131192

132-
return self.get_snippets(indices, channel_ids=channel_ids, return_scaled=return_scaled)
193+
# Handle deprecated return_scaled parameter
194+
if return_scaled is not None:
195+
warn(
196+
"`return_scaled` is deprecated and will be removed in a future version. Use `return_in_uV` instead.",
197+
category=DeprecationWarning,
198+
stacklevel=2,
199+
)
200+
return_in_uV = return_scaled
201+
202+
return self.get_snippets(indices, channel_ids=channel_ids, return_in_uV=return_in_uV)
133203

134204
def _save(self, format="binary", **save_kwargs):
135205
raise NotImplementedError

src/spikeinterface/core/frameslicerecording.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,18 @@ class FrameSliceRecording(BaseRecording):
2727
def __init__(self, parent_recording, start_frame=None, end_frame=None):
2828
channel_ids = parent_recording.get_channel_ids()
2929

30-
assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment"
30+
num_segments = parent_recording.get_num_segments()
31+
assert num_segments == 1, f"FrameSliceRecording only works with one segment but found {num_segments}"
3132

32-
parent_size = parent_recording.get_num_samples(segment_index=0)
33-
if start_frame is None:
34-
start_frame = 0
35-
else:
36-
assert 0 <= start_frame < parent_size
37-
38-
if end_frame is None:
39-
end_frame = parent_size
40-
else:
41-
assert (
42-
0 < end_frame <= parent_size
43-
), f"'end_frame' must be fewer than number of samples in parent: {parent_size}"
33+
samples_in_recording = parent_recording.get_num_samples(segment_index=0)
34+
start_frame = start_frame or 0
35+
end_frame = end_frame or samples_in_recording
4436

45-
assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!"
37+
assert start_frame >= 0, f"{start_frame=} must be positive"
38+
assert start_frame < end_frame, f"{start_frame=} must be smaller than 'end_frame' {end_frame=}!"
39+
assert (
40+
end_frame <= samples_in_recording
41+
), f"{end_frame=} must be smaller than or equal to {samples_in_recording=}"
4642

4743
BaseRecording.__init__(
4844
self,
@@ -53,7 +49,7 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None):
5349

5450
# link recording segment
5551
parent_segment = parent_recording._recording_segments[0]
56-
sub_segment = FrameSliceRecordingSegment(parent_segment, int(start_frame), int(end_frame))
52+
sub_segment = FrameSliceRecordingSegment(parent_segment, start_frame=int(start_frame), end_frame=int(end_frame))
5753
self.add_recording_segment(sub_segment)
5854

5955
# copy properties and annotations

0 commit comments

Comments
 (0)