Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,12 @@
get_best_job_kwargs,
ensure_n_jobs,
ensure_chunk_size,
TimeSeriesChunkExecutor,
ChunkExecutor,
split_job_kwargs,
fix_job_kwargs,
)
from .recording_tools import (
write_binary_recording,
write_memory_recording,
write_recording_to_zarr,
write_to_h5_dataset_format,
get_random_data_chunks,
get_channel_distances,
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from probeinterface import read_probeinterface, write_probeinterface

from .time_series import TimeSeriesSegment, TimeSeries
from .chunkable import ChunkableSegment, ChunkableMixin
from .baserecordingsnippets import BaseRecordingSnippets
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
from .job_tools import split_job_kwargs


class BaseRecording(BaseRecordingSnippets, TimeSeries):
class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
"""
Abstract class representing several a multichannel timeseries (or block of raw ephys traces).
Internally handle list of RecordingSegment
Expand Down Expand Up @@ -305,7 +305,7 @@ def get_traces(

def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
"""
General retrieval function for time_series objects
General retrieval function for chunkable objects
"""
return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs)

Expand All @@ -316,7 +316,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
kwargs, job_kwargs = split_job_kwargs(save_kwargs)

if format == "binary":
from .time_series_tools import write_binary
from .chunkable_tools import write_binary

folder = kwargs["folder"]
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
Expand Down Expand Up @@ -642,7 +642,7 @@ def astype(self, dtype, round: bool | None = None):
return astype(self, dtype=dtype, round=round)


class BaseRecordingSegment(TimeSeriesSegment):
class BaseRecordingSegment(ChunkableSegment):
"""
Abstract class representing a multichannel timeseries, or block of raw ephys traces
"""
Expand Down Expand Up @@ -677,6 +677,6 @@ def get_data(
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
) -> np.ndarray:
"""
General retrieval function for time_series objects
General retrieval function for chunkable objects
"""
return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices)
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
from spikeinterface.core.base import BaseExtractor, BaseSegment


class TimeSeries(ABC):
class ChunkableMixin(ABC):
"""
Abstract base class for time series extractors: continuous data sampled along a time axis
that supports chunked access for parallelization. The class can only be used by extractors
that inherit from BaseExtractor.
Abstract mixin class for chunkable objects. Note that the mixin can only be used
for classes that inherit from BaseExtractor.
Provides methods to handle chunked data access, that can be used for parallelization.
In addition, since chunkable objects are continuous data, time handling methods are provided.

Provides the chunking contract (``get_data``, ``get_shape``, ``get_sample_size_in_bytes``,
memory-size helpers, multiprocessing hints) and time-handling methods built on top of it.
All abstract methods must be implemented in the child class.
The Mixin is abstract since all methods need to be implemented in the child class in order
for it to function properly.
"""

_preferred_mp_context = None

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not issubclass(cls, BaseExtractor):
raise TypeError(f"{cls.__name__} must inherit from BaseExtractor to use TimeSeries.")
raise TypeError(f"{cls.__name__} must inherit from BaseExtractor to use Chunkable mixin.")

@abstractmethod
def get_sampling_frequency(self) -> float:
Expand All @@ -45,13 +45,13 @@ def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]:
def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
raise NotImplementedError

def _extra_copy_metadata(self, other: "TimeSeries", **kwargs) -> None:
def _extra_copy_metadata(self, other: "ChunkableMixin", **kwargs) -> None:
"""
Copy metadata from another TimeSeries object.
Copy metadata from another Chunkable object.

Parameters
----------
other : TimeSeries
other : ChunkableMixin
The object from which to copy metadata.
"""
# inherit preferred mp context if any
Expand Down Expand Up @@ -362,9 +362,8 @@ def _get_time_vectors(self):
return time_vectors


class TimeSeriesSegment(BaseSegment):
"""Per-segment time-series class. Provides time handling methods (sample/time conversion,
start/end time, time vectors) on top of ``BaseSegment``."""
class ChunkableSegment(BaseSegment):
"""Class for chunkable segments, which provide methods to handle time kwargs."""

def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):
# sampling_frequency and time_vector are exclusive
Expand Down
Loading
Loading