Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import warnings
from typing import Literal
import numpy as np

from spikeinterface.core.core_tools import define_function_handling_dict_from_class

Expand Down Expand Up @@ -64,7 +65,9 @@ class CommonReferenceRecording(BasePreprocessor):
annulus. The exclude radius is used to exclude channels that are too close to the reference channel and the
include radius delineates the outer boundary of the annulus whose role is to exclude channels
that are too far away.

min_local_neighbors : int, default: 5
Used in the local CAR implementation to set a minimum number of neighbors. If fewer channels than this number fall
within the annulus, the `min_local_neighbors` closest channels beyond the inner (exclude) radius are used instead.
dtype : None or dtype, default: None
If None the parent dtype is kept.

Expand All @@ -87,10 +90,10 @@ def __init__(
dtype: str | np.dtype | None = None,
):
num_chans = recording.get_num_channels()
neighbors = None
local_kernel = None
# some checks
if reference not in ("global", "single", "local"):
raise ValueError("'reference' must be either 'global', 'single', 'local'")
raise ValueError("'reference' must be either 'global', 'single', or 'local'")
if operator not in ("median", "average"):
raise ValueError("'operator' must be either 'median', 'average'")

Expand All @@ -116,12 +119,28 @@ def __init__(
elif reference == "local":
assert groups is None, "With 'local' CAR, the group option should not be used."
closest_inds, dist = get_closest_channels(recording)
neighbors = {}
# The neighbor kernel is a matrix that will be used to calculate the local reference.
# It has shape (num_chans, num_chans) and is filled with zeros except for the columns corresponding to the
# neighbors of each channel, which are filled with 1 / number of neighbors. This way, when we do a dot
# product between the traces and the neighbor kernel, we get the local average reference for each channel.
# For the median operator, the neighbors are extracted from the kernel on-the-fly via nonzero.
local_kernel = np.zeros((num_chans, num_chans))
not_enough_channels = []
for i in range(num_chans):
mask = dist[i, :] > local_radius[0]
nn = np.cumsum(mask)
mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors))
neighbors[i] = closest_inds[i, mask]
annulus_mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1])
if np.sum(annulus_mask) >= min_local_neighbors:
neighbors_i = closest_inds[i, annulus_mask]
else:
# Not enough channels in the annulus — take the closest ones beyond the inner radius
not_enough_channels.append(recording.channel_ids[i])
beyond_inner = dist[i, :] > local_radius[0]
neighbors_i = closest_inds[i, beyond_inner][:min_local_neighbors]
local_kernel[i, neighbors_i] = 1 / len(neighbors_i)
if len(not_enough_channels) > 0:
warnings.warn(
f"The following channels did not have enough neighbors in the annulus and used the closest "
f"{min_local_neighbors} channels beyond the inner radius instead: {', '.join(not_enough_channels)}"
)
dtype_ = fix_dtype(recording, dtype)
BasePreprocessor.__init__(self, recording, dtype=dtype_)

Expand All @@ -142,8 +161,7 @@ def __init__(
operator,
group_indices,
ref_channel_indices,
local_radius,
neighbors,
local_kernel,
dtype_,
)
self.add_recording_segment(rec_segment)
Expand All @@ -168,8 +186,7 @@ def __init__(
operator,
group_indices,
ref_channel_indices,
local_radius,
neighbors,
local_kernel,
dtype,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
Expand All @@ -178,11 +195,11 @@ def __init__(
self.operator = operator
self.group_indices = group_indices
self.ref_channel_indices = ref_channel_indices
self.local_radius = local_radius
self.neighbors = neighbors
self.local_kernel = local_kernel
self.temp = None
self.dtype = dtype
self.operator_func = operator = np.mean if self.operator == "average" else np.median
self.operator = operator
self.operator_func = np.mean if self.operator == "average" else np.median

def get_traces(self, start_frame, end_frame, channel_indices):
# Let's do the case with group_indices equal None as that is easy
Expand All @@ -201,13 +218,17 @@ def get_traces(self, start_frame, end_frame, channel_indices):
shift = traces[:, self.ref_channel_indices]
re_referenced_traces = traces[:, channel_indices] - shift
else: # then it must be local
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
for i, channel_index in enumerate(channel_indices_array):
channel_neighborhood = self.neighbors[channel_index]
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift

if self.operator == "median":
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
for i, channel_index in enumerate(channel_indices_array):
channel_neighborhood = np.nonzero(self.local_kernel[channel_index])[0]
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift
else: # then it must be local average, use local_kernel
re_referenced_traces = (
traces[:, channel_indices] - traces.dot(self.local_kernel.T)[:, channel_indices]
)
return re_referenced_traces.astype(self.dtype, copy=False)

# Then the old implementation for backwards compatibility that supports grouping
Expand Down
67 changes: 58 additions & 9 deletions src/spikeinterface/preprocessing/tests/test_common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def test_common_reference(recording):
rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])
rec_car = common_reference(recording, reference="global", operator="average")
rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"])
rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
rec_local_cmr = common_reference(
recording, reference="local", local_radius=(25, 65), operator="median", min_local_neighbors=1
)
rec_local_car = common_reference(
recording, reference="local", local_radius=(25, 65), operator="average", min_local_neighbors=1
)

traces = recording.get_traces()
assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
Expand All @@ -35,21 +40,26 @@ def test_common_reference(recording):
assert not np.all(rec_sin.get_traces()[0])
assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])

assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01)
assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01)
assert np.allclose(traces[:, 0], rec_local_cmr.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01)
assert np.allclose(traces[:, 1], rec_local_cmr.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01)

assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.mean(traces[:, [2, 3]], axis=1), atol=0.01)
assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.mean(traces[:, [3]], axis=1), atol=0.01)

# Saving tests
rec_cmr.save(verbose=False)
rec_car.save(verbose=False)
rec_sin.save(verbose=False)
rec_local_cmr.save(verbose=False)
rec_local_car.save(verbose=False)


def test_common_reference_channel_slicing(recording):
recording_cmr = common_reference(recording, reference="global", operator="median")
recording_car = common_reference(recording, reference="global", operator="average")
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"])
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
recording_local_cmr = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="average")

channel_ids = ["b", "d"]
indices = recording.ids_to_indices(channel_ids)
Expand All @@ -73,9 +83,12 @@ def test_common_reference_channel_slicing(recording):

assert np.allclose(single_reference_trace, expected_trace, atol=0.01)

local_trace = recording_local_cmr.get_traces(channel_ids=all_channel_ids)
local_trace_sub = recording_local_cmr.get_traces(channel_ids=channel_ids)
assert np.all(local_trace[:, indices] == local_trace_sub)

local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids)
local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids)

assert np.all(local_trace[:, indices] == local_trace_sub)

# test segment slicing
Expand Down Expand Up @@ -157,8 +170,44 @@ def test_common_reference_groups(recording):
assert np.allclose(traces[:, 1], 0)


def test_min_local_radius():
    """Check the ``min_local_neighbors`` fallback of the local reference.

    When a channel has fewer than ``min_local_neighbors`` channels inside the
    ``local_radius`` annulus, ``common_reference`` is expected to emit a warning
    and still return a usable recording.
    """
    recording = generate_recording(durations=[1.0], num_channels=32)
    # Remove the channels that directly follow the first one, so that the first
    # channel is left with a sparse neighborhood.
    # NOTE(review): this presumably leaves fewer than 5 channels inside the
    # (60, 150) annulus for at least one channel; that depends on the probe
    # geometry produced by `generate_recording` - confirm.
    recording = recording.remove_channels(recording.channel_ids[1:5])

    # `warnings.warn` without an explicit category raises a UserWarning.
    with pytest.warns(UserWarning):
        recording_local_car = common_reference(
            recording, reference="local", local_radius=(60, 150), operator="average", min_local_neighbors=5
        )

    # The fallback must not change the shape of the traces: one re-referenced
    # column per input channel.
    referenced_traces = recording_local_car.get_traces()
    assert referenced_traces.shape == recording.get_traces().shape


@pytest.mark.skip(reason="This test can be used to check local CAR vs local CMR performance")
def test_local_car_vs_cmr_performance():
    """Benchmark local average (CAR) against local median (CMR) referencing.

    Skipped by default because it compares wall-clock timings. It builds a
    384-channel recording, times one full ``get_traces()`` call per operator
    and asserts that the average operator is the faster of the two.
    """
    from time import perf_counter

    recording = generate_recording(durations=[10.0], num_channels=384)

    # Time the median first and the average second, one full read each.
    elapsed = {}
    for operator in ("median", "average"):
        referenced = common_reference(
            recording, reference="local", local_radius=(20, 65), operator=operator, min_local_neighbors=1
        )
        tic = perf_counter()
        referenced.get_traces()
        elapsed[operator] = perf_counter() - tic

    cmr_time = elapsed["median"]
    car_time = elapsed["average"]

    print(f"Local CMR time: {cmr_time:.4f} seconds")
    print(f"Local CAR time: {car_time:.4f} seconds")
    assert car_time < cmr_time


if __name__ == "__main__":
    # Manual entry point: run the tests of this module without pytest.
    recording = _generate_test_recording()
    test_common_reference(recording)
    test_common_reference_channel_slicing(recording)
    test_common_reference_groups(recording)
    # Builds its own recording, so it takes no argument. Run it before the
    # benchmark, whose timing assertion may fail and stop the script.
    test_min_local_radius()
    # Calling the function directly bypasses its `pytest.mark.skip` marker.
    test_local_car_vs_cmr_performance()
Loading