basesorting.py get_spike_trains method #3946

Closed
pas-calc wants to merge 2 commits into SpikeInterface:main from pas-calc:patch-1

Conversation

@pas-calc
Contributor

Convenient function to get the spike trains of all units as a dict.

An alternative method could be:

spike_vector = sorting.to_spike_vector()
for unit_id in sorting.unit_ids:
    spikes = spike_vector[spike_vector["unit_index"] == sorting.id_to_index(unit_id)]
    # np.unique(spike_vector["unit_index"], return_counts=True)  # number of spikes per unit
    ...
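The masking loop above can be sketched on synthetic data. This is a minimal stand-in, not SpikeInterface itself: the structured-array field names `"sample_index"` and `"unit_index"` follow the spike-vector dtype, but the spike data and unit ids here are invented for illustration.

```python
import numpy as np

# Synthetic stand-in for sorting.to_spike_vector(): a structured array of
# (sample_index, unit_index) pairs, sorted by sample index.
spike_vector = np.array(
    [(10, 0), (25, 1), (40, 0), (55, 2), (70, 1)],
    dtype=[("sample_index", "int64"), ("unit_index", "int64")],
)
unit_ids = ["a", "b", "c"]  # hypothetical unit ids; unit_ids[i] <-> unit_index i

# Group spikes per unit by masking on unit_index, as in the loop above.
spike_trains = {
    unit_id: spike_vector["sample_index"][spike_vector["unit_index"] == unit_index]
    for unit_index, unit_id in enumerate(unit_ids)
}
print(spike_trains["a"])  # -> [10 40]

# Spike counts per unit in a single call:
unit_indices, counts = np.unique(spike_vector["unit_index"], return_counts=True)
```

Note that the boolean mask scans the whole spike vector once per unit, which is what motivates the faster grouping approach discussed below in the thread.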

pas-calc and others added 2 commits May 22, 2025 16:07
convenient function to get spike trains of all units as dict
@chrishalcrow
Member

Hey @pas-calc , I think this is a great idea! I suspect there will be some debate about how to implement this. Just to note: there's another way to do this as follows:

spike_trains_samples = si.spike_vector_to_spike_trains(sorting.to_spike_vector(concatenated=False), unit_ids=sorting.unit_ids)

# go from sample index to seconds
spike_trains = {
    segment_id: {
        unit_id: spike_train / sorting.sampling_frequency
        for unit_id, spike_train in spike_trains_segment.items()
    }
    for segment_id, spike_trains_segment in spike_trains_samples.items()
}

spike_vector_to_spike_trains uses a numba implementation, so it should be a lot faster. It might be worth benchmarking on some real data. And you also need to worry about segments, I'm afraid.
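The samples-to-seconds step of that comprehension can be sketched on a synthetic nested dict of the shape `{segment_index: {unit_id: sample_indices}}` (the per-segment shape discussed above). The segment contents, unit ids, and sampling frequency below are all made up for illustration.

```python
import numpy as np

sampling_frequency = 30_000.0  # hypothetical rate in Hz

# Synthetic per-segment spike trains in sample indices.
spike_trains_samples = {
    0: {"u0": np.array([300, 30_000]), "u1": np.array([15_000])},
    1: {"u0": np.array([600]), "u1": np.array([])},
}

# Convert every train from sample indices to seconds, segment by segment.
spike_trains = {
    segment_id: {
        unit_id: train / sampling_frequency
        for unit_id, train in trains_per_unit.items()
    }
    for segment_id, trains_per_unit in spike_trains_samples.items()
}
# spike_trains[0]["u0"] holds 0.01 s and 1.0 s
```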

I would also vote to call it to_spike_trains to match the to_spike_vector method.

@h-mayorquin
Collaborator

get_unit_spike_train is basically that dict with some extra functionality:

def get_unit_spike_train(
    self,
    unit_id: str | int,
    segment_index: Union[int, None] = None,
    start_frame: Union[int, None] = None,
    end_frame: Union[int, None] = None,
    return_times: bool = False,
    use_cache: bool = True,
):
    segment_index = self._check_segment_index(segment_index)
    if use_cache:
        if segment_index not in self._cached_spike_trains:
            self._cached_spike_trains[segment_index] = {}
        if unit_id not in self._cached_spike_trains[segment_index]:
            segment = self._sorting_segments[segment_index]
            spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype(
                "int64", copy=False
            )
            self._cached_spike_trains[segment_index][unit_id] = spike_frames
        else:
            spike_frames = self._cached_spike_trains[segment_index][unit_id]
        if start_frame is not None:
            start = np.searchsorted(spike_frames, start_frame)
            spike_frames = spike_frames[start:]
        if end_frame is not None:
            end = np.searchsorted(spike_frames, end_frame)
            spike_frames = spike_frames[:end]
    else:
        segment = self._sorting_segments[segment_index]
        spike_frames = segment.get_unit_spike_train(
            unit_id=unit_id, start_frame=start_frame, end_frame=end_frame
        ).astype("int64")
    if return_times:
        if self.has_recording():
            times = self.get_times(segment_index=segment_index)
            return times[spike_frames]
        else:
            segment = self._sorting_segments[segment_index]
            t_start = segment._t_start if segment._t_start is not None else 0
            spike_times = spike_frames / self.get_sampling_frequency()
            return t_start + spike_times
    else:
        return spike_frames
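The start_frame/end_frame trimming in the cached branch relies on the spike frames being sorted, so a half-open [start_frame, end_frame) window reduces to two binary searches. A small sketch on an invented sorted array:

```python
import numpy as np

# Spike frames are kept sorted, so trimming to a window is two searchsorted
# calls rather than a boolean mask over the whole array.
spike_frames = np.array([5, 12, 30, 47, 60, 88], dtype="int64")
start_frame, end_frame = 12, 60  # hypothetical window

start = np.searchsorted(spike_frames, start_frame)  # first index >= start_frame
spike_frames = spike_frames[start:]
end = np.searchsorted(spike_frames, end_frame)      # first index >= end_frame
spike_frames = spike_frames[:end]
print(spike_frames)  # -> [12 30 47]
```

With searchsorted's default side="left", a spike exactly at start_frame is kept and a spike exactly at end_frame is dropped, so the window is inclusive of the start and exclusive of the end.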

@samuelgarcia
Member

Hi.
Thanks for this proposal.
I think we need to debate the naming. to_spike_trains_dict could be a candidate.
The missing piece here is the segment. What should this function return for a multi-segment sorting? A list of dicts?
And for a single segment? Also a list of dicts with one element?

@alejoe91 alejoe91 added the core Changes to core module label Jun 3, 2025
@h-mayorquin
Collaborator

Thinking more about this I am against adding a method to the sorting extractor class that can be accomplished in one line:

spike_trains_dict = {unit_id: sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0) for unit_id in sorting.unit_ids}
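That one-liner can be exercised on a toy object. The class below is a made-up stand-in, not a SpikeInterface sorting; only the get_unit_spike_train call shape mirrors the method discussed in this thread, and the spike data is invented.

```python
import numpy as np

class ToySorting:
    """Hypothetical stand-in exposing unit_ids and get_unit_spike_train."""
    unit_ids = ["u0", "u1"]
    _trains = {"u0": np.array([3, 9]), "u1": np.array([5])}

    def get_unit_spike_train(self, unit_id, segment_index=0):
        # A real sorting would slice per segment; this toy has one segment.
        return self._trains[unit_id]

sorting = ToySorting()

# The one-liner from the comment above:
spike_trains_dict = {
    unit_id: sorting.get_unit_spike_train(unit_id=unit_id, segment_index=0)
    for unit_id in sorting.unit_ids
}
print(spike_trains_dict["u0"])  # -> [3 9]
```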

Let's not increase the API exposure area at the core level, this leaks complexity to all the other places in the library.

@alejoe91
Member

Hi @pas-calc

Thanks for the input, but we discussed internally and we agree with @h-mayorquin. Since it's a one-liner to get the dict of spike trains, let's not overload the core API!

@alejoe91 alejoe91 closed this Jun 12, 2025
@pas-calc pas-calc deleted the patch-1 branch February 2, 2026 13:00