Skip to content

Commit 693a306

Browse files
committed
Improvements over chuking refactoring
1 parent 10834d4 commit 693a306

17 files changed

Lines changed: 251 additions & 93 deletions

src/spikeinterface/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,14 @@
9696
get_best_job_kwargs,
9797
ensure_n_jobs,
9898
ensure_chunk_size,
99-
ChunkExecutor,
99+
TimeSeriesChunkExecutor,
100100
split_job_kwargs,
101101
fix_job_kwargs,
102102
)
103103
from .recording_tools import (
104104
write_binary_recording,
105+
write_memory_recording,
106+
write_recording_to_zarr,
105107
write_to_h5_dataset_format,
106108
get_random_data_chunks,
107109
get_channel_distances,

src/spikeinterface/core/baserecording.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import numpy as np
66
from probeinterface import read_probeinterface, write_probeinterface
77

8-
from .chunkable import ChunkableSegment, ChunkableMixin
8+
from .time_series import TimeSeriesSegment, TimeSeries
99
from .baserecordingsnippets import BaseRecordingSnippets
1010
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
1111
from .job_tools import split_job_kwargs
1212

1313

14-
class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
14+
class BaseRecording(BaseRecordingSnippets, TimeSeries):
1515
"""
1616
Abstract class representing several a multichannel timeseries (or block of raw ephys traces).
1717
Internally handle list of RecordingSegment
@@ -311,7 +311,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
311311
kwargs, job_kwargs = split_job_kwargs(save_kwargs)
312312

313313
if format == "binary":
314-
from .chunkable_tools import write_binary
314+
from .time_series_tools import write_binary
315315

316316
folder = kwargs["folder"]
317317
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
@@ -637,7 +637,7 @@ def astype(self, dtype, round: bool | None = None):
637637
return astype(self, dtype=dtype, round=round)
638638

639639

640-
class BaseRecordingSegment(ChunkableSegment):
640+
class BaseRecordingSegment(TimeSeriesSegment):
641641
"""
642642
Abstract class representing a multichannel timeseries, or block of raw ephys traces
643643
"""

src/spikeinterface/core/job_tools.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
205205
return chunks
206206

207207

208-
def divide_chunkable_into_chunks(recording, chunk_size):
208+
def divide_time_series_into_chunks(recording, chunk_size):
209209
slices = []
210210
for segment_index in range(recording.get_num_segments()):
211211
num_frames = recording.get_num_samples(segment_index)
@@ -242,7 +242,7 @@ def ensure_n_jobs(extractor, n_jobs=1):
242242
return n_jobs
243243

244244

245-
def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"):
245+
def chunk_duration_to_chunk_size(chunk_duration, chunkable: "TimeSeries"):
246246
if isinstance(chunk_duration, float):
247247
chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
248248
elif isinstance(chunk_duration, str):
@@ -259,7 +259,7 @@ def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"):
259259

260260

261261
def ensure_chunk_size(
262-
chunkable: "ChunkableMixin",
262+
chunkable: "TimeSeries",
263263
total_memory=None,
264264
chunk_size=None,
265265
chunk_memory=None,
@@ -320,7 +320,7 @@ def ensure_chunk_size(
320320
return chunk_size
321321

322322

323-
class ChunkExecutor:
323+
class TimeSeriesChunkExecutor:
324324
"""
325325
Core class for parallel processing to run a "function" over chunks on a chunkable extractor.
326326
@@ -334,7 +334,7 @@ class ChunkExecutor:
334334
335335
Parameters
336336
----------
337-
chunkable : ChunkableMixin
337+
chunkable : TimeSeries
338338
The chunkable object to be processed.
339339
func : function
340340
Function that runs on each chunk
@@ -383,7 +383,7 @@ class ChunkExecutor:
383383

384384
def __init__(
385385
self,
386-
chunkable: "ChunkableMixin",
386+
chunkable: "TimeSeries",
387387
func,
388388
init_func,
389389
init_args,
@@ -487,7 +487,7 @@ def run(self, slices=None):
487487

488488
if slices is None:
489489
# TODO: rename
490-
slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size)
490+
slices = divide_time_series_into_chunks(self.chunkable, self.chunk_size)
491491

492492
if self.handle_returns:
493493
returns = []

src/spikeinterface/core/node_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
import numpy as np
1212

1313
from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype
14-
from spikeinterface.core.chunkable import ChunkableMixin
14+
from spikeinterface.core.time_series import TimeSeries
1515
from spikeinterface.core import BaseRecording, get_chunk_with_margin
16-
from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
16+
from spikeinterface.core.job_tools import TimeSeriesChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
1717
from spikeinterface.core import get_channel_distances
1818
from spikeinterface.core.core_tools import ms_to_samples
1919

@@ -26,7 +26,7 @@ class PipelineNode:
2626

2727
def __init__(
2828
self,
29-
chunkable: ChunkableMixin,
29+
chunkable: TimeSeries,
3030
return_output: bool | tuple[bool] = True,
3131
parents: list[Type["PipelineNode"]] | None = None,
3232
):
@@ -38,7 +38,7 @@ def __init__(
3838
3939
Parameters
4040
----------
41-
chunkable : ChunkableMixin
41+
chunkable : TimeSeries
4242
The chunkable object.
4343
return_output : bool or tuple[bool], default: True
4444
Whether or not the output of the node is returned by the pipeline.
@@ -526,7 +526,7 @@ def check_graph(nodes, check_for_peak_source=True):
526526

527527

528528
def run_node_pipeline(
529-
chunkable: ChunkableMixin,
529+
chunkable: TimeSeries,
530530
nodes: list[PipelineNode],
531531
job_kwargs: dict,
532532
job_name: str = "pipeline",
@@ -566,7 +566,7 @@ def run_node_pipeline(
566566
567567
Parameters
568568
----------
569-
chunkable: ChunkableMixin
569+
chunkable: TimeSeries
570570
The chunkable object to run the pipeline on. This is typically a recording but it can be anything that have the
571571
same interface for getting chunks with margin.
572572
nodes: a list of PipelineNode
@@ -628,7 +628,7 @@ def run_node_pipeline(
628628

629629
init_args = (chunkable, nodes, skip_after_n_peaks_per_worker)
630630

631-
processor = ChunkExecutor(
631+
processor = TimeSeriesChunkExecutor(
632632
chunkable,
633633
_compute_peak_pipeline_chunk,
634634
_init_peak_pipeline,

src/spikeinterface/core/recording_tools.py

Lines changed: 159 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,169 @@
1313
ensure_chunk_size,
1414
divide_segment_into_chunks,
1515
fix_job_kwargs,
16-
ChunkExecutor,
16+
TimeSeriesChunkExecutor,
1717
_shared_job_kwargs_doc,
1818
split_job_kwargs,
1919
)
2020

21-
from .chunkable_tools import get_random_sample_slices, get_chunks, get_chunk_with_margin
21+
from .time_series_tools import get_random_sample_slices, get_chunks, get_chunk_with_margin
22+
from .time_series_tools import write_binary as _write_binary
23+
from .time_series_tools import write_memory as _write_memory
24+
from .time_series_tools import _write_time_series_to_zarr
2225

23-
# for back-compatibility imports
24-
from .chunkable_tools import write_binary as write_binary_recording
25-
from .chunkable_tools import write_memory as write_memory_recording
26+
27+
def write_binary_recording(
28+
recording,
29+
file_paths,
30+
file_timestamps_paths=None,
31+
dtype=None,
32+
add_file_extension=True,
33+
byte_offset=0,
34+
verbose=False,
35+
**job_kwargs,
36+
):
37+
"""
38+
Save the traces of a recording to binary format.
39+
40+
Parameters
41+
----------
42+
recording : BaseRecording
43+
The recording to save to binary file.
44+
file_paths : list[Path | str] | Path | str
45+
The path to the files to save data for each segment.
46+
file_timestamps_paths : list[Path | str] | Path | str | None, default: None
47+
The path to the timestamps file. If None, timestamps are not saved.
48+
dtype : dtype or None, default: None
49+
Type of the saved data.
50+
add_file_extension : bool, default: True
51+
If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension.
52+
byte_offset : int, default: 0
53+
Offset in bytes for the binary file (e.g. to write a header).
54+
verbose : bool, default: False
55+
Verbosity of the chunk executor.
56+
{}
57+
"""
58+
return _write_binary(
59+
recording,
60+
file_paths=file_paths,
61+
file_timestamps_paths=file_timestamps_paths,
62+
dtype=dtype,
63+
add_file_extension=add_file_extension,
64+
byte_offset=byte_offset,
65+
verbose=verbose,
66+
**job_kwargs,
67+
)
68+
69+
70+
write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc)
71+
72+
73+
def write_memory_recording(
74+
recording,
75+
dtype=None,
76+
verbose=False,
77+
buffer_type="auto",
78+
job_name="write_memory",
79+
**job_kwargs,
80+
):
81+
"""
82+
Save the traces of a recording into numpy arrays in memory.
83+
84+
Uses SharedMemory when ``n_jobs > 1``.
85+
86+
Parameters
87+
----------
88+
recording : BaseRecording
89+
The recording to save to memory.
90+
dtype : dtype, default: None
91+
Type of the saved data.
92+
verbose : bool, default: False
93+
If True, output is verbose (when chunks are used).
94+
buffer_type : "auto" | "numpy" | "sharedmem", default: "auto"
95+
The type of buffer to use for storing the data.
96+
job_name : str, default: "write_memory"
97+
Name of the job.
98+
{}
99+
100+
Returns
101+
-------
102+
arrays : list
103+
One array per segment.
104+
"""
105+
return _write_memory(
106+
recording,
107+
dtype=dtype,
108+
verbose=verbose,
109+
buffer_type=buffer_type,
110+
job_name=job_name,
111+
**job_kwargs,
112+
)
113+
114+
115+
write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc)
116+
117+
118+
def write_recording_to_zarr(
119+
recording,
120+
zarr_group,
121+
dataset_paths,
122+
dataset_timestamps_paths=None,
123+
extra_chunks=None,
124+
dtype=None,
125+
compressor_data=None,
126+
filters_data=None,
127+
compressor_times=None,
128+
filters_times=None,
129+
verbose=False,
130+
**job_kwargs,
131+
):
132+
"""
133+
Save the traces of a recording to zarr format.
134+
135+
Parameters
136+
----------
137+
recording : BaseRecording
138+
The recording to save in zarr format.
139+
zarr_group : zarr.Group
140+
The zarr group to add traces to.
141+
dataset_paths : list
142+
List of paths to traces datasets in the zarr group.
143+
dataset_timestamps_paths : list or None, default: None
144+
List of paths to timestamps datasets in the zarr group. If None, timestamps are not saved.
145+
extra_chunks : tuple or None, default: None
146+
Extra chunking dimensions to use for the zarr dataset. The first dimension is always time and
147+
controlled by the job_kwargs. Useful to chunk by channel, with ``extra_chunks=(channel_chunk_size,)``.
148+
dtype : dtype, default: None
149+
Type of the saved data.
150+
compressor_data : zarr compressor or None, default: None
151+
Zarr compressor for data.
152+
filters_data : list, default: None
153+
List of zarr filters for data.
154+
compressor_times : zarr compressor or None, default: None
155+
Zarr compressor for timestamps.
156+
filters_times : list, default: None
157+
List of zarr filters for timestamps.
158+
verbose : bool, default: False
159+
If True, output is verbose (when chunks are used).
160+
{}
161+
"""
162+
return _write_time_series_to_zarr(
163+
recording,
164+
zarr_group=zarr_group,
165+
dataset_paths=dataset_paths,
166+
dataset_timestamps_paths=dataset_timestamps_paths,
167+
extra_chunks=extra_chunks,
168+
dtype=dtype,
169+
compressor_data=compressor_data,
170+
filters_data=filters_data,
171+
compressor_times=compressor_times,
172+
filters_times=filters_times,
173+
verbose=verbose,
174+
**job_kwargs,
175+
)
176+
177+
178+
write_recording_to_zarr.__doc__ = write_recording_to_zarr.__doc__.format(_shared_job_kwargs_doc)
26179

27180

28181
def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0):
@@ -458,7 +611,7 @@ def append_noise_chunk(res):
458611
func = _noise_level_chunk
459612
init_func = _noise_level_chunk_init
460613
init_args = (recording, return_in_uV, method)
461-
executor = ChunkExecutor(
614+
executor = TimeSeriesChunkExecutor(
462615
recording,
463616
func,
464617
init_func,

0 commit comments

Comments
 (0)