Skip to content

Commit 37a40b9

Browse files
committed
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
2 parents e880394 + 409d1f5 commit 37a40b9

19 files changed

Lines changed: 460 additions & 179 deletions

src/spikeinterface/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,14 @@
9797
get_best_job_kwargs,
9898
ensure_n_jobs,
9999
ensure_chunk_size,
100-
ChunkExecutor,
100+
TimeSeriesChunkExecutor,
101101
split_job_kwargs,
102102
fix_job_kwargs,
103103
)
104104
from .recording_tools import (
105105
write_binary_recording,
106+
write_memory_recording,
107+
write_recording_to_zarr,
106108
write_to_h5_dataset_format,
107109
get_random_data_chunks,
108110
get_channel_distances,

src/spikeinterface/core/baserecording.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import numpy as np
66
from probeinterface import read_probeinterface, write_probeinterface
77

8-
from .chunkable import ChunkableSegment, ChunkableMixin
8+
from .time_series import TimeSeriesSegment, TimeSeries
99
from .baserecordingsnippets import BaseRecordingSnippets
1010
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
1111
from .job_tools import split_job_kwargs
1212

1313

14-
class BaseRecording(BaseRecordingSnippets, ChunkableMixin):
14+
class BaseRecording(BaseRecordingSnippets, TimeSeries):
1515
"""
1616
Abstract class representing several a multichannel timeseries (or block of raw ephys traces).
1717
Internally handle list of RecordingSegment
@@ -305,7 +305,7 @@ def get_traces(
305305

306306
def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray:
307307
"""
308-
General retrieval function for chunkable objects
308+
General retrieval function for time_series objects
309309
"""
310310
return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs)
311311

@@ -316,7 +316,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
316316
kwargs, job_kwargs = split_job_kwargs(save_kwargs)
317317

318318
if format == "binary":
319-
from .chunkable_tools import write_binary
319+
from .time_series_tools import write_binary
320320

321321
folder = kwargs["folder"]
322322
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
@@ -642,7 +642,7 @@ def astype(self, dtype, round: bool | None = None):
642642
return astype(self, dtype=dtype, round=round)
643643

644644

645-
class BaseRecordingSegment(ChunkableSegment):
645+
class BaseRecordingSegment(TimeSeriesSegment):
646646
"""
647647
Abstract class representing a multichannel timeseries, or block of raw ephys traces
648648
"""
@@ -677,6 +677,6 @@ def get_data(
677677
self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None
678678
) -> np.ndarray:
679679
"""
680-
General retrieval function for chunkable objects
680+
General retrieval function for time_series objects
681681
"""
682682
return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices)

src/spikeinterface/core/basesorting.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def register_recording(self, recording, check_spike_frames: bool = True):
327327
"Might be necessary for further postprocessing."
328328
)
329329
self._recording = recording
330+
# Copy the recording's start times into the sorting segments. This way,
331+
# the sorting preserves the start time even if the recording is later
332+
# detached (e.g. analyzer saved and reloaded without the recording).
333+
for segment_index, segment in enumerate(self.segments):
334+
segment._t_start = recording.get_start_time(segment_index=segment_index)
330335

331336
@property
332337
def sorting_info(self):
@@ -352,6 +357,66 @@ def has_time_vector(self, segment_index: int | None = None) -> bool:
352357
else:
353358
return False
354359

360+
def get_start_time(self, segment_index: int | None = None) -> float:
361+
"""Get the start time of the sorting segment.
362+
363+
Parameters
364+
----------
365+
segment_index : int or None, default: None
366+
The segment index (required for multi-segment)
367+
368+
Returns
369+
-------
370+
float
371+
The start time in seconds
372+
"""
373+
segment_index = self._check_segment_index(segment_index)
374+
segment = self.segments[segment_index]
375+
return segment._t_start if segment._t_start is not None else 0.0
376+
377+
def get_end_time(self, segment_index: int | None = None) -> float:
378+
"""Get the end time of the sorting segment.
379+
380+
If a recording is registered, returns the recording's end time.
381+
Otherwise returns the time of the last spike in the segment.
382+
383+
Parameters
384+
----------
385+
segment_index : int or None, default: None
386+
The segment index (required for multi-segment)
387+
388+
Returns
389+
-------
390+
float
391+
The end time in seconds
392+
"""
393+
segment_index = self._check_segment_index(segment_index)
394+
if self.has_recording():
395+
return self._recording.get_end_time(segment_index=segment_index)
396+
else:
397+
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
398+
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)
399+
400+
def get_last_spike_frame(self, segment_index: int | None = None) -> int:
401+
"""Get the frame index of the last spike in a segment across all units.
402+
403+
Parameters
404+
----------
405+
segment_index : int or None, default: None
406+
The segment index (required for multi-segment)
407+
408+
Returns
409+
-------
410+
int
411+
The frame index of the last spike, or 0 if no spikes exist.
412+
"""
413+
segment_index = self._check_segment_index(segment_index)
414+
spike_vector = self.to_spike_vector(concatenated=False)
415+
spikes_in_segment = spike_vector[segment_index]
416+
if len(spikes_in_segment) == 0:
417+
return 0
418+
return int(np.max(spikes_in_segment["sample_index"]))
419+
355420
def get_times(self, segment_index=None):
356421
"""
357422
Get time vector for a registered recording segment.

src/spikeinterface/core/job_tools.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
205205
return chunks
206206

207207

208-
def divide_chunkable_into_chunks(recording, chunk_size):
208+
def divide_time_series_into_chunks(recording, chunk_size):
209209
slices = []
210210
for segment_index in range(recording.get_num_segments()):
211211
num_frames = recording.get_num_samples(segment_index)
@@ -242,24 +242,24 @@ def ensure_n_jobs(extractor, n_jobs=1):
242242
return n_jobs
243243

244244

245-
def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"):
245+
def chunk_duration_to_chunk_size(chunk_duration, time_series: "TimeSeries"):
246246
if isinstance(chunk_duration, float):
247-
chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
247+
chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
248248
elif isinstance(chunk_duration, str):
249249
if chunk_duration.endswith("ms"):
250250
chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0
251251
elif chunk_duration.endswith("s"):
252252
chunk_duration = float(chunk_duration.replace("s", ""))
253253
else:
254254
raise ValueError("chunk_duration must ends with s or ms")
255-
chunk_size = int(chunk_duration * chunkable.get_sampling_frequency())
255+
chunk_size = int(chunk_duration * time_series.get_sampling_frequency())
256256
else:
257257
raise ValueError("chunk_duration must be str or float")
258258
return chunk_size
259259

260260

261261
def ensure_chunk_size(
262-
chunkable: "ChunkableMixin",
262+
time_series: "TimeSeries",
263263
total_memory=None,
264264
chunk_size=None,
265265
chunk_memory=None,
@@ -299,30 +299,30 @@ def ensure_chunk_size(
299299
assert total_memory is None
300300
# set by memory per worker size
301301
chunk_memory = convert_string_to_bytes(chunk_memory)
302-
chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes())
302+
chunk_size = int(chunk_memory / time_series.get_sample_size_in_bytes())
303303
elif total_memory is not None:
304304
# clip by total memory size
305-
n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs)
305+
n_jobs = ensure_n_jobs(time_series, n_jobs=n_jobs)
306306
total_memory = convert_string_to_bytes(total_memory)
307-
chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs))
307+
chunk_size = int(total_memory / (time_series.get_sample_size_in_bytes() * n_jobs))
308308
elif chunk_duration is not None:
309-
chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable)
309+
chunk_size = chunk_duration_to_chunk_size(chunk_duration, time_series)
310310
else:
311311
# Edge case to define single chunk per segment for n_jobs=1.
312312
# All chunking parameters equal None mean single chunk per segment
313313
if n_jobs == 1:
314-
num_segments = chunkable.get_num_segments()
315-
samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)])
314+
num_segments = time_series.get_num_segments()
315+
samples_in_larger_segment = max([time_series.get_num_samples(segment) for segment in range(num_segments)])
316316
chunk_size = samples_in_larger_segment
317317
else:
318318
raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory")
319319

320320
return chunk_size
321321

322322

323-
class ChunkExecutor:
323+
class TimeSeriesChunkExecutor:
324324
"""
325-
Core class for parallel processing to run a "function" over chunks on a chunkable extractor.
325+
Core class for parallel processing to run a "function" over chunks on a time_series extractor.
326326
327327
It supports running a function:
328328
* in loop with chunk processing (low RAM usage)
@@ -334,8 +334,8 @@ class ChunkExecutor:
334334
335335
Parameters
336336
----------
337-
chunkable : ChunkableMixin
338-
The chunkable object to be processed.
337+
time_series : TimeSeries
338+
The time_series object to be processed.
339339
func : function
340340
Function that runs on each chunk
341341
init_func : function
@@ -383,7 +383,7 @@ class ChunkExecutor:
383383

384384
def __init__(
385385
self,
386-
chunkable: "ChunkableMixin",
386+
time_series: "TimeSeries",
387387
func,
388388
init_func,
389389
init_args,
@@ -402,7 +402,7 @@ def __init__(
402402
max_threads_per_worker=1,
403403
need_worker_index=False,
404404
):
405-
self.chunkable = chunkable
405+
self.time_series = time_series
406406
self.func = func
407407
self.init_func = init_func
408408
self.init_args = init_args
@@ -421,7 +421,7 @@ def __init__(
421421
else:
422422
mp_context = "spawn"
423423

424-
preferred_mp_context = chunkable.get_preferred_mp_context()
424+
preferred_mp_context = time_series.get_preferred_mp_context()
425425
if preferred_mp_context is not None and preferred_mp_context != mp_context:
426426
warnings.warn(
427427
f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible."
@@ -437,7 +437,7 @@ def __init__(
437437
self.handle_returns = handle_returns
438438
self.gather_func = gather_func
439439

440-
self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs)
440+
self.n_jobs = ensure_n_jobs(self.time_series, n_jobs=n_jobs)
441441
self.chunk_size = self.ensure_chunk_size(
442442
total_memory=total_memory,
443443
chunk_size=chunk_size,
@@ -455,7 +455,7 @@ def __init__(
455455
if verbose:
456456
chunk_memory = self.get_chunk_memory()
457457
total_memory = chunk_memory * self.n_jobs
458-
chunk_duration = self.chunk_size / chunkable.sampling_frequency
458+
chunk_duration = self.chunk_size / time_series.sampling_frequency
459459
chunk_memory_str = convert_bytes_to_str(chunk_memory)
460460
total_memory_str = convert_bytes_to_str(total_memory)
461461
chunk_duration_str = convert_seconds_to_str(chunk_duration)
@@ -471,13 +471,13 @@ def __init__(
471471
)
472472

473473
def get_chunk_memory(self):
474-
return self.chunk_size * self.chunkable.get_sample_size_in_bytes()
474+
return self.chunk_size * self.time_series.get_sample_size_in_bytes()
475475

476476
def ensure_chunk_size(
477477
self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs
478478
):
479479
return ensure_chunk_size(
480-
self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
480+
self.time_series, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs
481481
)
482482

483483
def run(self, slices=None):
@@ -487,7 +487,7 @@ def run(self, slices=None):
487487

488488
if slices is None:
489489
# TODO: rename
490-
slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size)
490+
slices = divide_time_series_into_chunks(self.time_series, self.chunk_size)
491491

492492
if self.handle_returns:
493493
returns = []

0 commit comments

Comments
 (0)