diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 288889dd9e..b1469a0250 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,5 +1,6 @@ -import numpy as np +import warnings from typing import Literal +import numpy as np from spikeinterface.core.core_tools import define_function_handling_dict_from_class @@ -64,7 +65,9 @@ class CommonReferenceRecording(BasePreprocessor): annulus. The exclude radius is used to exclude channels that are too close to the reference channel and the include radius delineates the outer boundary of the annulus whose role is to exclude channels that are too far away. - + min_local_neighbors : int, default: 5 + Use in the local CAR implementation to set a minimum number of neighbors. If the number of neighbors within the + annulus is less than this number, then the closest neighbors are used until this number is reached. dtype : None or dtype, default: None If None the parent dtype is kept. @@ -87,10 +90,10 @@ def __init__( dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() - neighbors = None + local_kernel = None # some checks if reference not in ("global", "single", "local"): - raise ValueError("'reference' must be either 'global', 'single', 'local'") + raise ValueError("'reference' must be either 'global', 'single', or 'local'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") @@ -116,12 +119,28 @@ def __init__( elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) - neighbors = {} + # The neighbor kernel is a matrix that will be used to calculate the local reference. + # It has shape (num_chans, num_chans) and is filled with zeros except for the columns corresponding to the + # neighbors of each channel, which are filled with 1 / number of neighbors. This way, when we do a dot + # product between the traces and the neighbor kernel, we get the local average reference for each channel. + # For the median operator, the neighbors are extracted from the kernel on-the-fly via nonzero. + local_kernel = np.zeros((num_chans, num_chans)) + not_enough_channels = [] for i in range(num_chans): - mask = dist[i, :] > local_radius[0] - nn = np.cumsum(mask) - mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) - neighbors[i] = closest_inds[i, mask] + annulus_mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) + if np.sum(annulus_mask) >= min_local_neighbors: + neighbors_i = closest_inds[i, annulus_mask] + else: + # Not enough channels in the annulus — take the closest ones beyond the inner radius + not_enough_channels.append(recording.channel_ids[i]) + beyond_inner = dist[i, :] > local_radius[0] + neighbors_i = closest_inds[i, beyond_inner][:min_local_neighbors] + local_kernel[i, neighbors_i] = 1 / len(neighbors_i) + if len(not_enough_channels) > 0: + warnings.warn( + f"The following channels did not have enough neighbors in the annulus and used the closest " + f"{min_local_neighbors} channels beyond the inner radius instead: {', '.join(not_enough_channels)}" + ) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -142,8 +161,7 @@ def __init__( operator, group_indices, ref_channel_indices, - local_radius, - neighbors, + local_kernel, dtype_, ) self.add_recording_segment(rec_segment) @@ -168,8 +186,7 @@ def __init__( operator, group_indices, ref_channel_indices, - local_radius, - neighbors, + local_kernel, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -178,11 +195,11 @@ def __init__( self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices - self.local_radius = local_radius - self.neighbors = neighbors + self.local_kernel = local_kernel self.temp = None self.dtype = dtype - self.operator_func = operator = np.mean if self.operator == "average" else np.median + self.operator = operator + self.operator_func = np.mean if self.operator == "average" else np.median def get_traces(self, start_frame, end_frame, channel_indices): # Let's do the case with group_indices equal None as that is easy @@ -201,13 +218,17 @@ def get_traces(self, start_frame, end_frame, channel_indices): shift = traces[:, self.ref_channel_indices] re_referenced_traces = traces[:, channel_indices] - shift else: # then it must be local - channel_indices_array = np.arange(traces.shape[1])[channel_indices] - re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") - for i, channel_index in enumerate(channel_indices_array): - channel_neighborhood = self.neighbors[channel_index] - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - + if self.operator == "median": + channel_indices_array = np.arange(traces.shape[1])[channel_indices] + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): + channel_neighborhood = np.nonzero(self.local_kernel[channel_index])[0] + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + else: # then it must be local average, use local_kernel + re_referenced_traces = ( + traces[:, channel_indices] - traces.dot(self.local_kernel.T)[:, channel_indices] + ) return re_referenced_traces.astype(self.dtype, copy=False) # Then the old implementation for backwards compatibility that supports grouping diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 8b37e7f4b9..3fbc260b5f 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -26,7 +26,12 @@ def test_common_reference(recording): rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]) rec_car = common_reference(recording, reference="global", operator="average") rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"]) - rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + rec_local_cmr = common_reference( + recording, reference="local", local_radius=(25, 65), operator="median", min_local_neighbors=1 + ) + rec_local_car = common_reference( + recording, reference="local", local_radius=(25, 65), operator="average", min_local_neighbors=1 + ) traces = recording.get_traces() assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01) @@ -35,13 +40,17 @@ def test_common_reference(recording): assert not np.all(rec_sin.get_traces()[0]) assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0]) - assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01) - assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 0], rec_local_cmr.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 1], rec_local_cmr.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01) + + assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.mean(traces[:, [2, 3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.mean(traces[:, [3]], axis=1), atol=0.01) # Saving tests rec_cmr.save(verbose=False) rec_car.save(verbose=False) rec_sin.save(verbose=False) + rec_local_cmr.save(verbose=False) rec_local_car.save(verbose=False) @@ -49,7 +58,8 @@ def test_common_reference_channel_slicing(recording): recording_cmr = common_reference(recording, reference="global", operator="median") recording_car = common_reference(recording, reference="global", operator="average") recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"]) - recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + recording_local_cmr = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="average") channel_ids = ["b", "d"] indices = recording.ids_to_indices(channel_ids) @@ -73,9 +83,12 @@ def test_common_reference_channel_slicing(recording): assert np.allclose(single_reference_trace, expected_trace, atol=0.01) + local_trace = recording_local_cmr.get_traces(channel_ids=all_channel_ids) + local_trace_sub = recording_local_cmr.get_traces(channel_ids=channel_ids) + assert np.all(local_trace[:, indices] == local_trace_sub) + local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids) local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids) - assert np.all(local_trace[:, indices] == local_trace_sub) # test segment slicing @@ -157,8 +170,44 @@ def test_common_reference_groups(recording): assert np.allclose(traces[:, 1], 0) +def test_min_local_radius(): + # Test that local radius smaller than the number of channels is handled correctly + recording = generate_recording(durations=[1.0], num_channels=32) + # remove closest channel to first channel + recording = recording.remove_channels(recording.channel_ids[1:5]) + with pytest.warns(UserWarning): + recording_local_car = common_reference( + recording, reference="local", local_radius=(60, 150), operator="average", min_local_neighbors=5 + ) + + +@pytest.mark.skip(reason="This test can be used to check local CAR vs local CMR performance") +def test_local_car_vs_cmr_performance(): + import time + + # Test that local CAR is faster than local CMR when there are many channels + recording = generate_recording(durations=[10.0], num_channels=384) + + rec_local_cmr = common_reference( + recording, reference="local", local_radius=(20, 65), operator="median", min_local_neighbors=1 + ) + t_start_cmr = time.perf_counter() + rec_local_cmr.get_traces() + t_end_cmr = time.perf_counter() + cmr_time = t_end_cmr - t_start_cmr + + rec_local_car = common_reference( + recording, reference="local", local_radius=(20, 65), operator="average", min_local_neighbors=1 + ) + t_start_car = time.perf_counter() + rec_local_car.get_traces() + t_end_car = time.perf_counter() + car_time = t_end_car - t_start_car + + print(f"Local CMR time: {cmr_time:.4f} seconds") + print(f"Local CAR time: {car_time:.4f} seconds") + assert car_time < cmr_time + + if __name__ == "__main__": - recording = _generate_test_recording() - test_common_reference(recording) - test_common_reference_channel_slicing(recording) - test_common_reference_groups(recording) + test_local_car_vs_cmr_performance()