Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@
get_best_job_kwargs,
ensure_n_jobs,
ensure_chunk_size,
ChunkExecutor,
TimeSeriesChunkExecutor,
split_job_kwargs,
fix_job_kwargs,
)
from .recording_tools import (
write_binary_recording,
write_memory_recording,
write_recording_to_zarr,
write_to_h5_dataset_format,
get_random_data_chunks,
get_channel_distances,
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from probeinterface import read_probeinterface, write_probeinterface

from .chunkable import ChunkableSegment, ChunkableMixin
from .time_series import TimeSeriesSegment, TimeSeries
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs


class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
class BaseRecording(BaseRecordingSnippets, TimeSeries):
"""
Abstract class representing a multichannel timeseries (or block of raw ephys traces).
Internally handle list of RecordingSegment
Expand Down Expand Up @@ -300,7 +300,7 @@ def get_traces(

def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
    """
    General retrieval function for time_series objects.

    Thin wrapper that delegates to ``get_traces`` so this recording
    satisfies the generic ``get_data`` interface used by chunk-processing
    machinery.

    Parameters
    ----------
    start_frame : int
        First frame (sample index) to retrieve.
    end_frame : int
        End frame (exclusive) of the slice to retrieve.
    segment_index : int | None, default: None
        Segment to read from; ``None`` is resolved by ``get_traces``.
    **kwargs
        Forwarded unchanged to ``get_traces``.

    Returns
    -------
    np.ndarray
        The requested traces.
    """
    return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs)

Expand All @@ -311,7 +311,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
from .chunkable_tools import write_binary
from .time_series_tools import write_binary

folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
Expand Down Expand Up @@ -637,7 +637,7 @@ def astype(self, dtype, round: bool | None = None):
return astype(self, dtype=dtype, round=round)


class BaseRecordingSegment(ChunkableSegment):
class BaseRecordingSegment(TimeSeriesSegment):
"""
Abstract class representing a multichannel timeseries, or block of raw ephys traces
"""
Expand Down Expand Up @@ -672,6 +672,6 @@ def get_data(
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
) -> np.ndarray:
"""
General retrieval function for chunkable objects
General retrieval function for time_series objects
"""
return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices)
46 changes: 23 additions & 23 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
return chunks


def divide_chunkable_into_chunks(recording, chunk_size):
def divide_time_series_into_chunks(recording, chunk_size):
slices = []
for segment_index in range(recording.get_num_segments()):
num_frames = recording.get_num_samples(segment_index)
Expand Down Expand Up @@ -242,24 +242,24 @@ def ensure_n_jobs(extractor, n_jobs=1):
return n_jobs


def chunk_duration_to_chunk_size(chunk_duration, time_series: "TimeSeries"):
    """
    Convert a chunk duration into a chunk size expressed in samples.

    Parameters
    ----------
    chunk_duration : float | str
        Duration of one chunk. A float is interpreted as seconds; a string
        must end with "s" (seconds) or "ms" (milliseconds), e.g. "1.5s" or "500ms".
    time_series : TimeSeries
        Object exposing ``get_sampling_frequency()``, used to convert
        seconds into a number of samples.

    Returns
    -------
    int
        Chunk size in number of samples (truncated toward zero).

    Raises
    ------
    ValueError
        If ``chunk_duration`` is a string without an "s"/"ms" suffix,
        or is neither a float nor a string.
    """
    if isinstance(chunk_duration, float):
        chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
    elif isinstance(chunk_duration, str):
        # Normalize the string to a duration in seconds before converting.
        # "ms" must be checked first because "...ms" also ends with "s".
        if chunk_duration.endswith("ms"):
            chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
        elif chunk_duration.endswith("s"):
            chunk_duration = float(chunk_duration.replace("s", ""))
        else:
            raise ValueError("chunk_duration must ends with s or ms")
        chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
    else:
        raise ValueError("chunk_duration must be str or float")
    return chunk_size


def ensure_chunk_size(
chunkable: "ChunkableMixin",
time_series: "TimeSeries",
total_memory=None,
chunk_size=None,
chunk_memory=None,
Expand Down Expand Up @@ -299,30 +299,30 @@ def ensure_chunk_size(
assert total_memory is None
# set by memory per worker size
chunk_memory = convert_string_to_bytes(chunk_memory)
chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes())
chunk_size = int(chunk_memory / time_series.get_sample_size_in_bytes())
elif total_memory is not None:
# clip by total memory size
n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs)
n_jobs = ensure_n_jobs(time_series, n_jobs=n_jobs)
total_memory = convert_string_to_bytes(total_memory)
chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs))
chunk_size = int(total_memory / (time_series.get_sample_size_in_bytes() * n_jobs))
elif chunk_duration is not None:
chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable)
chunk_size = chunk_duration_to_chunk_size(chunk_duration, time_series)
else:
# Edge case to define single chunk per segment for n_jobs=1.
# All chunking parameters equal None mean single chunk per segment
if n_jobs == 1:
num_segments = chunkable.get_num_segments()
samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)])
num_segments = time_series.get_num_segments()
samples_in_larger_segment = max([time_series.get_num_samples(segment) for segment in range(num_segments)])
chunk_size = samples_in_larger_segment
else:
raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory")

return chunk_size


class ChunkExecutor:
class TimeSeriesChunkExecutor:
"""
Core class for parallel processing to run a "function" over chunks on a chunkable extractor.
Core class for parallel processing to run a "function" over chunks on a time_series extractor.

It supports running a function:
* in loop with chunk processing (low RAM usage)
Expand All @@ -334,8 +334,8 @@ class ChunkExecutor:

Parameters
----------
chunkable : ChunkableMixin
The chunkable object to be processed.
time_series : TimeSeries
The time_series object to be processed.
func : function
Function that runs on each chunk
init_func : function
Expand Down Expand Up @@ -383,7 +383,7 @@ class ChunkExecutor:

def __init__(
self,
chunkable: "ChunkableMixin",
time_series: "TimeSeries",
func,
init_func,
init_args,
Expand All @@ -402,7 +402,7 @@ def __init__(
max_threads_per_worker=1,
need_worker_index=False,
):
self.chunkable = chunkable
self.time_series = time_series
self.func = func
self.init_func = init_func
self.init_args = init_args
Expand All @@ -421,7 +421,7 @@ def __init__(
else:
mp_context = "spawn"

preferred_mp_context = chunkable.get_preferred_mp_context()
preferred_mp_context = time_series.get_preferred_mp_context()
if preferred_mp_context is not None and preferred_mp_context != mp_context:
warnings.warn(
f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible."
Expand All @@ -437,7 +437,7 @@ def __init__(
self.handle_returns = handle_returns
self.gather_func = gather_func

self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs)
self.n_jobs = ensure_n_jobs(self.time_series, n_jobs=n_jobs)
self.chunk_size = self.ensure_chunk_size(
total_memory=total_memory,
chunk_size=chunk_size,
Expand All @@ -455,7 +455,7 @@ def __init__(
if verbose:
chunk_memory = self.get_chunk_memory()
total_memory = chunk_memory * self.n_jobs
chunk_duration = self.chunk_size / chunkable.sampling_frequency
chunk_duration = self.chunk_size / time_series.sampling_frequency
chunk_memory_str = convert_bytes_to_str(chunk_memory)
total_memory_str = convert_bytes_to_str(total_memory)
chunk_duration_str = convert_seconds_to_str(chunk_duration)
Expand All @@ -471,13 +471,13 @@ def __init__(
)

def get_chunk_memory(self):
return self.chunk_size * self.chunkable.get_sample_size_in_bytes()
return self.chunk_size * self.time_series.get_sample_size_in_bytes()

def ensure_chunk_size(
self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
):
return ensure_chunk_size(
self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
self.time_series, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
)

def run(self, slices=None):
Expand All @@ -487,7 +487,7 @@ def run(self, slices=None):

if slices is None:
# TODO: rename
slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size)
slices = divide_time_series_into_chunks(self.time_series, self.chunk_size)

if self.handle_returns:
returns = []
Expand Down
32 changes: 16 additions & 16 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import numpy as np

from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype
from spikeinterface.core.chunkable import ChunkableMixin
from spikeinterface.core.time_series import TimeSeries
from spikeinterface.core import BaseRecording, get_chunk_with_margin
from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
from spikeinterface.core.job_tools import TimeSeriesChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc
from spikeinterface.core import get_channel_distances
from spikeinterface.core.core_tools import ms_to_samples

Expand All @@ -26,7 +26,7 @@ class PipelineNode:

def __init__(
self,
chunkable: ChunkableMixin,
time_series: TimeSeries,
return_output: bool | tuple[bool] = True,
parents: list[Type["PipelineNode"]] | None = None,
):
Expand All @@ -38,16 +38,16 @@ def __init__(

Parameters
----------
chunkable : ChunkableMixin
The chunkable object.
time_series : TimeSeries
The time_series object.
return_output : bool or tuple[bool], default: True
Whether or not the output of the node is returned by the pipeline.
When a Node have several toutputs then this can be a tuple of bool
parents : list[PipelineNode] | None, default: None
Pass parents nodes to perform a previous computation.
"""

self.chunkable = chunkable
self.time_series = time_series
self.return_output = return_output
if isinstance(parents, str):
# only one parents is allowed
Expand Down Expand Up @@ -526,7 +526,7 @@ def check_graph(nodes, check_for_peak_source=True):


def run_node_pipeline(
chunkable: ChunkableMixin,
time_series: TimeSeries,
nodes: list[PipelineNode],
job_kwargs: dict,
job_name: str = "pipeline",
Expand Down Expand Up @@ -566,8 +566,8 @@ def run_node_pipeline(

Parameters
----------
chunkable: ChunkableMixin
The chunkable object to run the pipeline on. This is typically a recording but it can be anything that have the
time_series: TimeSeries
The time_series object to run the pipeline on. This is typically a recording but it can be anything that have the
same interface for getting chunks with margin.
nodes: a list of PipelineNode
The list of nodes to run in the pipeline. The order of the nodes is important as it defines
Expand Down Expand Up @@ -626,10 +626,10 @@ def run_node_pipeline(
# See need_first_call_before_pipeline : this trigger numba compilation before the run
node0._first_call_before_pipeline()

init_args = (chunkable, nodes, skip_after_n_peaks_per_worker)
init_args = (time_series, nodes, skip_after_n_peaks_per_worker)

processor = ChunkExecutor(
chunkable,
processor = TimeSeriesChunkExecutor(
time_series,
_compute_peak_pipeline_chunk,
_init_peak_pipeline,
init_args,
Expand All @@ -645,10 +645,10 @@ def run_node_pipeline(
return outs


def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker):
def _init_peak_pipeline(time_series, nodes, skip_after_n_peaks_per_worker):
# create a local dict per worker
worker_ctx = {}
worker_ctx["chunkable"] = chunkable
worker_ctx["time_series"] = time_series
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_margin() for node in nodes)
worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker
Expand All @@ -657,12 +657,12 @@ def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker):


def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx):
chunkable = worker_ctx["chunkable"]
time_series = worker_ctx["time_series"]
max_margin = worker_ctx["max_margin"]
nodes = worker_ctx["nodes"]
skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"]

chunkable_segment = chunkable.segments[segment_index]
chunkable_segment = time_series.segments[segment_index]
retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever))
# get peak slices once for all retrievers
peak_slice_by_retriever = {}
Expand Down
Loading
Loading