Skip to content

Commit 409d1f5

Browse files
Improvements over chunking refactoring (#4533)
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 651c8ea commit 409d1f5

17 files changed

Lines changed: 337 additions & 179 deletions

src/spikeinterface/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,14 @@
9797
get_best_job_kwargs,
9898
ensure_n_jobs,
9999
ensure_chunk_size,
100-
ChunkExecutor,
100+
TimeSeriesChunkExecutor,
101101
split_job_kwargs,
102102
fix_job_kwargs,
103103
)
104104
from .recording_tools import (
105105
write_binary_recording,
106+
write_memory_recording,
107+
write_recording_to_zarr,
106108
write_to_h5_dataset_format,
107109
get_random_data_chunks,
108110
get_channel_distances,

src/spikeinterface/core/baserecording.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import numpy as np
66
from probeinterface import read_probeinterface, write_probeinterface
77

8-
from .chunkable import ChunkableSegment, ChunkableMixin
8+
from .time_series import TimeSeriesSegment, TimeSeries
99
from .baserecordingsnippets import BaseRecordingSnippets
1010
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
1111
from .job_tools import split_job_kwargs
1212

1313

14-
class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
14+
class BaseRecording(BaseRecordingSnippets, TimeSeries):
1515
"""
1616
Abstract class representing several a multichannel timeseries (or block of raw ephys traces).
1717
Internally handle list of RecordingSegment
@@ -305,7 +305,7 @@ def get_traces(
305305

306306
def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
307307
"""
308-
General retrieval function for chunkable objects
308+
General retrieval function for time_series objects
309309
"""
310310
return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs)
311311

@@ -316,7 +316,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
316316
kwargs, job_kwargs = split_job_kwargs(save_kwargs)
317317

318318
if format == "binary":
319-
from .chunkable_tools import write_binary
319+
from .time_series_tools import write_binary
320320

321321
folder = kwargs["folder"]
322322
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
@@ -642,7 +642,7 @@ def astype(self, dtype, round: bool | None = None):
642642
return astype(self, dtype=dtype, round=round)
643643

644644

645-
class BaseRecordingSegment(ChunkableSegment):
645+
class BaseRecordingSegment(TimeSeriesSegment):
646646
"""
647647
Abstract class representing a multichannel timeseries, or block of raw ephys traces
648648
"""
@@ -677,6 +677,6 @@ def get_data(
677677
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
678678
) -> np.ndarray:
679679
"""
680-
General retrieval function for chunkable objects
680+
General retrieval function for time_series objects
681681
"""
682682
return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices)

src/spikeinterface/core/job_tools.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
205205
return chunks
206206

207207

208-
def divide_chunkable_into_chunks(recording, chunk_size):
208+
def divide_time_series_into_chunks(recording, chunk_size):
209209
slices = []
210210
for segment_index in range(recording.get_num_segments()):
211211
num_frames = recording.get_num_samples(segment_index)
@@ -242,24 +242,24 @@ def ensure_n_jobs(extractor, n_jobs=1):
242242
return n_jobs
243243

244244

245-
def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"):
245+
def chunk_duration_to_chunk_size(chunk_duration, time_series: "TimeSeries"):
246246
if isinstance(chunk_duration, float):
247-
chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
247+
chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
248248
elif isinstance(chunk_duration, str):
249249
if chunk_duration.endswith("ms"):
250250
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
251251
elif chunk_duration.endswith("s"):
252252
chunk_duration = float(chunk_duration.replace("s", ""))
253253
else:
254254
raise ValueError("chunk_duration must ends with s or ms")
255-
chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
255+
chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
256256
else:
257257
raise ValueError("chunk_duration must be str or float")
258258
return chunk_size
259259

260260

261261
def ensure_chunk_size(
262-
chunkable: "ChunkableMixin",
262+
time_series: "TimeSeries",
263263
total_memory=None,
264264
chunk_size=None,
265265
chunk_memory=None,
@@ -299,30 +299,30 @@ def ensure_chunk_size(
299299
assert total_memory is None
300300
# set by memory per worker size
301301
chunk_memory = convert_string_to_bytes(chunk_memory)
302-
chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes())
302+
chunk_size = int(chunk_memory / time_series.get_sample_size_in_bytes())
303303
elif total_memory is not None:
304304
# clip by total memory size
305-
n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs)
305+
n_jobs = ensure_n_jobs(time_series, n_jobs=n_jobs)
306306
total_memory = convert_string_to_bytes(total_memory)
307-
chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs))
307+
chunk_size = int(total_memory / (time_series.get_sample_size_in_bytes() * n_jobs))
308308
elif chunk_duration is not None:
309-
chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable)
309+
chunk_size = chunk_duration_to_chunk_size(chunk_duration, time_series)
310310
else:
311311
# Edge case to define single chunk per segment for n_jobs=1.
312312
# All chunking parameters equal None mean single chunk per segment
313313
if n_jobs == 1:
314-
num_segments = chunkable.get_num_segments()
315-
samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)])
314+
num_segments = time_series.get_num_segments()
315+
samples_in_larger_segment = max([time_series.get_num_samples(segment) for segment in range(num_segments)])
316316
chunk_size = samples_in_larger_segment
317317
else:
318318
raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory")
319319

320320
return chunk_size
321321

322322

323-
class ChunkExecutor:
323+
class TimeSeriesChunkExecutor:
324324
"""
325-
Core class for parallel processing to run a "function" over chunks on a chunkable extractor.
325+
Core class for parallel processing to run a "function" over chunks on a time_series extractor.
326326
327327
It supports running a function:
328328
* in loop with chunk processing (low RAM usage)
@@ -334,8 +334,8 @@ class ChunkExecutor:
334334
335335
Parameters
336336
----------
337-
chunkable : ChunkableMixin
338-
The chunkable object to be processed.
337+
time_series : TimeSeries
338+
The time_series object to be processed.
339339
func : function
340340
Function that runs on each chunk
341341
init_func : function
@@ -383,7 +383,7 @@ class ChunkExecutor:
383383

384384
def __init__(
385385
self,
386-
chunkable: "ChunkableMixin",
386+
time_series: "TimeSeries",
387387
func,
388388
init_func,
389389
init_args,
@@ -402,7 +402,7 @@ def __init__(
402402
max_threads_per_worker=1,
403403
need_worker_index=False,
404404
):
405-
self.chunkable = chunkable
405+
self.time_series = time_series
406406
self.func = func
407407
self.init_func = init_func
408408
self.init_args = init_args
@@ -421,7 +421,7 @@ def __init__(
421421
else:
422422
mp_context = "spawn"
423423

424-
preferred_mp_context = chunkable.get_preferred_mp_context()
424+
preferred_mp_context = time_series.get_preferred_mp_context()
425425
if preferred_mp_context is not None and preferred_mp_context != mp_context:
426426
warnings.warn(
427427
f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible."
@@ -437,7 +437,7 @@ def __init__(
437437
self.handle_returns = handle_returns
438438
self.gather_func = gather_func
439439

440-
self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs)
440+
self.n_jobs = ensure_n_jobs(self.time_series, n_jobs=n_jobs)
441441
self.chunk_size = self.ensure_chunk_size(
442442
total_memory=total_memory,
443443
chunk_size=chunk_size,
@@ -455,7 +455,7 @@ def __init__(
455455
if verbose:
456456
chunk_memory = self.get_chunk_memory()
457457
total_memory = chunk_memory * self.n_jobs
458-
chunk_duration = self.chunk_size / chunkable.sampling_frequency
458+
chunk_duration = self.chunk_size / time_series.sampling_frequency
459459
chunk_memory_str = convert_bytes_to_str(chunk_memory)
460460
total_memory_str = convert_bytes_to_str(total_memory)
461461
chunk_duration_str = convert_seconds_to_str(chunk_duration)
@@ -471,13 +471,13 @@ def __init__(
471471
)
472472

473473
def get_chunk_memory(self):
474-
return self.chunk_size * self.chunkable.get_sample_size_in_bytes()
474+
return self.chunk_size * self.time_series.get_sample_size_in_bytes()
475475

476476
def ensure_chunk_size(
477477
self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
478478
):
479479
return ensure_chunk_size(
480-
self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
480+
self.time_series, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
481481
)
482482

483483
def run(self, slices=None):
@@ -487,7 +487,7 @@ def run(self, slices=None):
487487

488488
if slices is None:
489489
# TODO: rename
490-
slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size)
490+
slices = divide_time_series_into_chunks(self.time_series, self.chunk_size)
491491

492492
if self.handle_returns:
493493
returns = []

src/spikeinterface/core/node_pipeline.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
import numpy as np
1212

1313
from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype
14-
from spikeinterface.core.chunkable import ChunkableMixin
14+
from spikeinterface.core.time_series import TimeSeries
1515
from spikeinterface.core import BaseRecording, get_chunk_with_margin
16-
from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
16+
from spikeinterface.core.job_tools import TimeSeriesChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
1717
from spikeinterface.core import get_channel_distances
1818
from spikeinterface.core.core_tools import ms_to_samples
1919

@@ -26,7 +26,7 @@ class PipelineNode:
2626

2727
def __init__(
2828
self,
29-
chunkable: ChunkableMixin,
29+
time_series: TimeSeries,
3030
return_output: bool | tuple[bool] = True,
3131
parents: list[Type["PipelineNode"]] | None = None,
3232
):
@@ -38,16 +38,16 @@ def __init__(
3838
3939
Parameters
4040
----------
41-
chunkable : ChunkableMixin
42-
The chunkable object.
41+
time_series : TimeSeries
42+
The time_series object.
4343
return_output : bool or tuple[bool], default: True
4444
Whether or not the output of the node is returned by the pipeline.
4545
When a Node have several toutputs then this can be a tuple of bool
4646
parents : list[PipelineNode] | None, default: None
4747
Pass parents nodes to perform a previous computation.
4848
"""
4949

50-
self.chunkable = chunkable
50+
self.time_series = time_series
5151
self.return_output = return_output
5252
if isinstance(parents, str):
5353
# only one parents is allowed
@@ -526,7 +526,7 @@ def check_graph(nodes, check_for_peak_source=True):
526526

527527

528528
def run_node_pipeline(
529-
chunkable: ChunkableMixin,
529+
time_series: TimeSeries,
530530
nodes: list[PipelineNode],
531531
job_kwargs: dict,
532532
job_name: str = "pipeline",
@@ -566,8 +566,8 @@ def run_node_pipeline(
566566
567567
Parameters
568568
----------
569-
chunkable: ChunkableMixin
570-
The chunkable object to run the pipeline on. This is typically a recording but it can be anything that have the
569+
time_series: TimeSeries
570+
The time_series object to run the pipeline on. This is typically a recording but it can be anything that have the
571571
same interface for getting chunks with margin.
572572
nodes: a list of PipelineNode
573573
The list of nodes to run in the pipeline. The order of the nodes is important as it defines
@@ -626,10 +626,10 @@ def run_node_pipeline(
626626
# See need_first_call_before_pipeline : this trigger numba compilation before the run
627627
node0._first_call_before_pipeline()
628628

629-
init_args = (chunkable, nodes, skip_after_n_peaks_per_worker)
629+
init_args = (time_series, nodes, skip_after_n_peaks_per_worker)
630630

631-
processor = ChunkExecutor(
632-
chunkable,
631+
processor = TimeSeriesChunkExecutor(
632+
time_series,
633633
_compute_peak_pipeline_chunk,
634634
_init_peak_pipeline,
635635
init_args,
@@ -645,10 +645,10 @@ def run_node_pipeline(
645645
return outs
646646

647647

648-
def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker):
648+
def _init_peak_pipeline(time_series, nodes, skip_after_n_peaks_per_worker):
649649
# create a local dict per worker
650650
worker_ctx = {}
651-
worker_ctx["chunkable"] = chunkable
651+
worker_ctx["time_series"] = time_series
652652
worker_ctx["nodes"] = nodes
653653
worker_ctx["max_margin"] = max(node.get_margin() for node in nodes)
654654
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
@@ -657,12 +657,12 @@ def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker):
657657

658658

659659
def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
660-
chunkable = worker_ctx["chunkable"]
660+
time_series = worker_ctx["time_series"]
661661
max_margin = worker_ctx["max_margin"]
662662
nodes = worker_ctx["nodes"]
663663
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]
664664

665-
chunkable_segment = chunkable.segments[segment_index]
665+
chunkable_segment = time_series.segments[segment_index]
666666
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
667667
# get peak slices once for all retrievers
668668
peak_slice_by_retriever = {}

0 commit comments

Comments
 (0)