From 9487c57c04f20aecb640330d4b087b0e0ba931f0 Mon Sep 17 00:00:00 2001 From: ZZUOHAN Date: Thu, 26 Feb 2026 16:25:34 -0500 Subject: [PATCH 01/10] introduce backup reference and knn reference backup reference for local reference when local channels too few; knn can be more robust than local as it ensures the # of reference channels --- .../preprocessing/common_reference.py | 83 +++++++++++++++---- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 3b5faa1381..07e7ff7794 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -8,6 +8,7 @@ from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype +from functools import cache class CommonReferenceRecording(BasePreprocessor): @@ -78,26 +79,30 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local"] = "global", + reference: Literal["global", "single", "local", 'knn'] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, local_radius: tuple[float, float] = (30.0, 55.0), + nneighbors: int | None = None, + backup_reference: Literal["global", "single", "knn"] = "global", + backup_thr: int = 1, dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None + knearest_neighbors = None # some checks - if reference not in ("global", "single", "local"): - raise ValueError("'reference' must be either 'global', 'single' or 'local'") + if reference not in ("global", "single", "local", "knn"): + raise ValueError("'reference' must be either 'global', 'single', 'local' or 'knn'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") - if reference == "global": + if reference == "global" or backup_reference == "global": if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - elif reference == "single": + if reference == "single" or reference == 'single': assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -112,15 +117,19 @@ def __init__( assert np.all( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" - elif reference == "local": + if reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) neighbors[i] = closest_inds[i, mask] - assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - + # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." + if reference == "knn" or backup_reference == 'knn': + assert groups is None, "With 'knn' CAR, the group option should not be used." + assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" + assert nneighbors > 0, "'nneighbors' must be positive" + knearest_neighbors, _ = get_closest_channels(recording, num_channels=min(nneighbors, num_chans)) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -136,7 +145,7 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, dtype_ + parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, knearest_neighbors, backup_reference, backup_thr, dtype_ ) self.add_recording_segment(rec_segment) @@ -147,6 +156,9 @@ def __init__( operator=operator, ref_channel_ids=ref_channel_ids, local_radius=local_radius, + nneighbors=nneighbors, + backup_reference=backup_reference, + backup_thr=backup_thr, dtype=dtype_.str, ) @@ -161,11 +173,17 @@ def __init__( ref_channel_indices, local_radius, neighbors, + knearest_neighbors, + backup_reference, + backup_thr, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference + self.knearest_neighbors = knearest_neighbors + self.backup_reference = backup_reference + self.backup_thr = backup_thr self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices @@ -181,23 +199,56 @@ def get_traces(self, start_frame, end_frame, channel_indices): # We need all the channels to calculate the reference traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) - if self.reference == "global": + @cache + def _global(keepdims=True): if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=True) + shift = self.operator_func(traces, axis=1, keepdims=keepdims) else: - shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) - re_referenced_traces = traces[:, channel_indices] - shift - elif self.reference == "single": + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=keepdims) + return shift + + @cache + def _single(): # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] - re_referenced_traces = traces[:, channel_indices] - shift - else: # then it must be local + return shift + + def _local(): channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] + if len(channel_neighborhood) < self.backup_thr: + if self.backup_reference == 'global': + channel_shift = _global(False) + elif self.backup_reference == 'single': + channel_shift = _single() + else: + channel_neighborhood = self.knearest_neighbors[channel_index] + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + else: + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + return re_referenced_traces + + + def _knn(): + channel_indices_array = np.arange(traces.shape[1])[channel_indices] + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): + channel_neighborhood = self.knearest_neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + return re_referenced_traces + + if self.reference == 'global': + re_referenced_traces = traces[:, channel_indices] - _global() + elif self.reference == 'single': + re_referenced_traces = traces[:, channel_indices] - _single() + elif self.reference == 'knn': + re_referenced_traces = _knn() + else: + re_referenced_traces = _local() return re_referenced_traces.astype(self.dtype, copy=False) From 013d810be15099888c1644ff8d8ecd7fe2d80b31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 21:44:17 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/common_reference.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 07e7ff7794..d3f59ff799 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -79,7 +79,7 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local", 'knn'] = "global", + reference: Literal["global", "single", "local", "knn"] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, @@ -102,7 +102,7 @@ def __init__( if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - if reference == "single" or reference == 'single': + if reference == "single" or reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -125,7 +125,7 @@ def __init__( mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) neighbors[i] = closest_inds[i, mask] # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - if reference == "knn" or backup_reference == 'knn': + if reference == "knn" or backup_reference == "knn": assert groups is None, "With 'knn' CAR, the group option should not be used." assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" assert nneighbors > 0, "'nneighbors' must be positive" @@ -145,7 +145,17 @@ def __init__( for parent_segment in recording._recording_segments: rec_segment = CommonReferenceRecordingSegment( - parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, knearest_neighbors, backup_reference, backup_thr, dtype_ + parent_segment, + reference, + operator, + group_indices, + ref_channel_indices, + local_radius, + neighbors, + knearest_neighbors, + backup_reference, + backup_thr, + dtype_, ) self.add_recording_segment(rec_segment) @@ -219,9 +229,9 @@ def _local(): for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] if len(channel_neighborhood) < self.backup_thr: - if self.backup_reference == 'global': + if self.backup_reference == "global": channel_shift = _global(False) - elif self.backup_reference == 'single': + elif self.backup_reference == "single": channel_shift = _single() else: channel_neighborhood = self.knearest_neighbors[channel_index] @@ -231,7 +241,6 @@ def _local(): re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces - def _knn(): channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") @@ -241,11 +250,11 @@ def _knn(): re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift return re_referenced_traces - if self.reference == 'global': + if self.reference == "global": re_referenced_traces = traces[:, channel_indices] - _global() - elif self.reference == 'single': + elif self.reference == "single": re_referenced_traces = traces[:, channel_indices] - _single() - elif self.reference == 'knn': + elif self.reference == "knn": re_referenced_traces = _knn() else: re_referenced_traces = _local() From ff1811177bfc95be6de27426116107aaf095b4cd Mon Sep 17 00:00:00 2001 From: ZZUOHAN Date: Fri, 27 Feb 2026 11:49:28 -0500 Subject: [PATCH 03/10] Remove KNN as a reference method & limit backup behavior to KNN only --- .../preprocessing/common_reference.py | 86 ++++--------------- 1 file changed, 18 insertions(+), 68 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index d3f59ff799..352afba263 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -79,30 +79,27 @@ class CommonReferenceRecording(BasePreprocessor): def __init__( self, recording: BaseRecording, - reference: Literal["global", "single", "local", "knn"] = "global", + reference: Literal["global", "single", "local"] = "global", operator: Literal["median", "average"] = "median", groups: list | None = None, ref_channel_ids: list | str | int | None = None, local_radius: tuple[float, float] = (30.0, 55.0), - nneighbors: int | None = None, - backup_reference: Literal["global", "single", "knn"] = "global", - backup_thr: int = 1, + min_local_neighbors: int = 5, dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() neighbors = None - knearest_neighbors = None # some checks - if reference not in ("global", "single", "local", "knn"): - raise ValueError("'reference' must be either 'global', 'single', 'local' or 'knn'") + if reference not in ("global", "single", "local"): + raise ValueError("'reference' must be either 'global', 'single', 'local'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") - if reference == "global" or backup_reference == "global": + if reference == "global": if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") - if reference == "single" or reference == "single": + elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length" @@ -117,19 +114,15 @@ def __init__( assert np.all( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" - if reference == "local": + elif reference == "local": assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): - mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) + mask = (dist[i, :] > local_radius[0]) + nn = np.cumsum(mask) + mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) neighbors[i] = closest_inds[i, mask] - # assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection." - if reference == "knn" or backup_reference == "knn": - assert groups is None, "With 'knn' CAR, the group option should not be used." - assert nneighbors is not None, "With 'knn' reference, provide 'nneighbors'" - assert nneighbors > 0, "'nneighbors' must be positive" - knearest_neighbors, _ = get_closest_channels(recording, num_channels=min(nneighbors, num_chans)) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -152,9 +145,6 @@ def __init__( ref_channel_indices, local_radius, neighbors, - knearest_neighbors, - backup_reference, - backup_thr, dtype_, ) self.add_recording_segment(rec_segment) @@ -166,9 +156,7 @@ def __init__( operator=operator, ref_channel_ids=ref_channel_ids, local_radius=local_radius, - nneighbors=nneighbors, - backup_reference=backup_reference, - backup_thr=backup_thr, + min_local_neighbors=min_local_neighbors, dtype=dtype_.str, ) @@ -183,17 +171,11 @@ def __init__( ref_channel_indices, local_radius, neighbors, - knearest_neighbors, - backup_reference, - backup_thr, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.reference = reference - self.knearest_neighbors = knearest_neighbors - self.backup_reference = backup_reference - self.backup_thr = backup_thr self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices @@ -209,55 +191,23 @@ def get_traces(self, start_frame, end_frame, channel_indices): # We need all the channels to calculate the reference traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None)) - @cache - def _global(keepdims=True): + if self.reference == "global": if self.ref_channel_indices is None: - shift = self.operator_func(traces, axis=1, keepdims=keepdims) + shift = self.operator_func(traces, axis=1, keepdims=True) else: - shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=keepdims) - return shift - - @cache - def _single(): + shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) + re_referenced_traces = traces[:, channel_indices] - shift + elif self.reference == "single": # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] - return shift - - def _local(): + re_referenced_traces = traces[:, channel_indices] - shift + else: # then it must be local channel_indices_array = np.arange(traces.shape[1])[channel_indices] re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") for i, channel_index in enumerate(channel_indices_array): channel_neighborhood = self.neighbors[channel_index] - if len(channel_neighborhood) < self.backup_thr: - if self.backup_reference == "global": - channel_shift = _global(False) - elif self.backup_reference == "single": - channel_shift = _single() - else: - channel_neighborhood = self.knearest_neighbors[channel_index] - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - else: - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - return re_referenced_traces - - def _knn(): - channel_indices_array = np.arange(traces.shape[1])[channel_indices] - re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") - for i, channel_index in enumerate(channel_indices_array): - channel_neighborhood = self.knearest_neighbors[channel_index] channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - return re_referenced_traces - - if self.reference == "global": - re_referenced_traces = traces[:, channel_indices] - _global() - elif self.reference == "single": - re_referenced_traces = traces[:, channel_indices] - _single() - elif self.reference == "knn": - re_referenced_traces = _knn() - else: - re_referenced_traces = _local() return re_referenced_traces.astype(self.dtype, copy=False) From 1972164caa4468db9fe142fbd0729c1c9f269f32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 16:50:02 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 352afba263..5231f8f6c8 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -119,7 +119,7 @@ def __init__( closest_inds, dist = get_closest_channels(recording) neighbors = {} for i in range(num_chans): - mask = (dist[i, :] > local_radius[0]) + mask = dist[i, :] > local_radius[0] nn = np.cumsum(mask) mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) neighbors[i] = closest_inds[i, mask] From 4ae2861b8abf9387b52fe9a55e8d52b9ed6dfbb8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 10:52:11 +0100 Subject: [PATCH 05/10] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/common_reference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5231f8f6c8..288889dd9e 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -8,7 +8,6 @@ from spikeinterface.core.baserecording import BaseRecording from .filter import fix_dtype -from functools import cache class CommonReferenceRecording(BasePreprocessor): From 80722c583eb3d8e1367753d364f89953bd8df91f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 11:28:58 +0100 Subject: [PATCH 06/10] CMR: Fix local neighbor and implement local_kernel for average --- .../preprocessing/common_reference.py | 77 +++++++++++++------ .../tests/test_common_reference.py | 47 ++++++++--- 2 files changed, 91 insertions(+), 33 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 288889dd9e..fc848d4e48 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -1,5 +1,6 @@ -import numpy as np +import warnings from typing import Literal +import numpy as np from spikeinterface.core.core_tools import define_function_handling_dict_from_class @@ -64,7 +65,9 @@ class CommonReferenceRecording(BasePreprocessor): annulus. The exclude radius is used to exclude channels that are too close to the reference channel and the include radius delineates the outer boundary of the annulus whose role is to exclude channels that are too far away. - + min_local_neighbors : int, default: 5 + Use in the local CAR implementation to set a minimum number of neighbors. If the number of neighbors within the + annulus is less than this number, then the closest neighbors are used until this number is reached. dtype : None or dtype, default: None If None the parent dtype is kept. @@ -87,10 +90,11 @@ def __init__( dtype: str | np.dtype | None = None, ): num_chans = recording.get_num_channels() - neighbors = None + local_neighbors = None + local_kernel = None # some checks if reference not in ("global", "single", "local"): - raise ValueError("'reference' must be either 'global', 'single', 'local'") + raise ValueError("'reference' must be either 'global', 'single' or 'local'") if operator not in ("median", "average"): raise ValueError("'operator' must be either 'median', 'average'") @@ -114,14 +118,34 @@ def __init__( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" elif reference == "local": + if operator == "median": + warnings.warn( + "Using a local median reference can be very computationally intensive. Consider using a local " + "average reference instead or pre-computing the local median reference and using the 'single' " + "reference option." + ) assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) - neighbors = {} + local_neighbors = {} + # The neighbor kernel is a matrix that will be used to calculate the local average reference. + # It has shape (num_chans, num_chans) and is filled with zeros except for the columns corresponding to the neighbors of each channel, which are filled with 1 / number of neighbors. This way, when we do a dot product between the traces and the neighbor kernel, we get the local average reference for each channel. + local_kernel = np.zeros((num_chans, num_chans)) + not_enough_channels = [] for i in range(num_chans): - mask = dist[i, :] > local_radius[0] - nn = np.cumsum(mask) - mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors)) - neighbors[i] = closest_inds[i, mask] + annulus_mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1]) + if np.sum(annulus_mask) >= min_local_neighbors: + local_neighbors[i] = closest_inds[i, annulus_mask] + else: + # Not enough channels in the annulus — take the closest ones beyond the inner radius + not_enough_channels.append(recording.channel_ids[i]) + beyond_inner = dist[i, :] > local_radius[0] + local_neighbors[i] = closest_inds[i, beyond_inner][:min_local_neighbors] + local_kernel[i, local_neighbors[i]] = 1 / len(local_neighbors[i]) + if len(not_enough_channels) > 0: + warnings.warn( + f"The following channels did not have enough neighbors in the annulus and used the closest " + f"{min_local_neighbors} channels beyond the inner radius instead: {', '.join(not_enough_channels)}" + ) dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, dtype=dtype_) @@ -142,8 +166,8 @@ def __init__( operator, group_indices, ref_channel_indices, - local_radius, - neighbors, + local_neighbors, + local_kernel, dtype_, ) self.add_recording_segment(rec_segment) @@ -168,8 +192,8 @@ def __init__( operator, group_indices, ref_channel_indices, - local_radius, - neighbors, + local_neighbors, + local_kernel, dtype, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) @@ -178,11 +202,12 @@ def __init__( self.operator = operator self.group_indices = group_indices self.ref_channel_indices = ref_channel_indices - self.local_radius = local_radius - self.neighbors = neighbors + self.local_neighbors = local_neighbors + self.local_kernel = local_kernel self.temp = None self.dtype = dtype - self.operator_func = operator = np.mean if self.operator == "average" else np.median + self.operator = operator + self.operator_func = np.mean if self.operator == "average" else np.median def get_traces(self, start_frame, end_frame, channel_indices): # Let's do the case with group_indices equal None as that is easy @@ -195,19 +220,23 @@ def get_traces(self, start_frame, end_frame, channel_indices): shift = self.operator_func(traces, axis=1, keepdims=True) else: shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) - re_referenced_traces = traces[:, channel_indices] - shift + re_referenced_traces = traces[:, channel_indices] - shift # shift[:, channel_indices] elif self.reference == "single": # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] re_referenced_traces = traces[:, channel_indices] - shift else: # then it must be local - channel_indices_array = np.arange(traces.shape[1])[channel_indices] - re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") - for i, channel_index in enumerate(channel_indices_array): - channel_neighborhood = self.neighbors[channel_index] - channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) - re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift - + if self.operator == "median": + channel_indices_array = np.arange(traces.shape[1])[channel_indices] + re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32") + for i, channel_index in enumerate(channel_indices_array): + channel_neighborhood = self.local_neighbors[channel_index] + channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1) + re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift + else: # then it must be local average, use local_kernel + re_referenced_traces = ( + traces[:, channel_indices] - traces.dot(self.local_kernel)[:, channel_indices] + ) return re_referenced_traces.astype(self.dtype, copy=False) # Then the old implementation for backwards compatibility that supports grouping diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 8b37e7f4b9..a4ef357b88 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -26,7 +26,12 @@ def test_common_reference(recording): rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"]) rec_car = common_reference(recording, reference="global", operator="average") rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"]) - rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + rec_local_cmr = common_reference( + recording, reference="local", local_radius=(25, 65), operator="median", min_local_neighbors=1 + ) + rec_local_car = common_reference( + recording, reference="local", local_radius=(25, 65), operator="average", min_local_neighbors=1 + ) traces = recording.get_traces() assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01) @@ -35,13 +40,18 @@ def test_common_reference(recording): assert not np.all(rec_sin.get_traces()[0]) assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0]) - assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01) - assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 0], rec_local_cmr.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 1], rec_local_cmr.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01) + + # TODO: fix this!!! + assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.mean(traces[:, [2, 3]], axis=1), atol=0.01) + assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.mean(traces[:, [3]], axis=1), atol=0.01) # Saving tests rec_cmr.save(verbose=False) rec_car.save(verbose=False) rec_sin.save(verbose=False) + rec_local_cmr.save(verbose=False) rec_local_car.save(verbose=False) @@ -49,7 +59,8 @@ def test_common_reference_channel_slicing(recording): recording_cmr = common_reference(recording, reference="global", operator="median") recording_car = common_reference(recording, reference="global", operator="average") recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"]) - recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + recording_local_cmr = common_reference(recording, reference="local", local_radius=(20, 65), operator="median") + recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="average") channel_ids = ["b", "d"] indices = recording.ids_to_indices(channel_ids) @@ -73,9 +84,12 @@ def test_common_reference_channel_slicing(recording): assert np.allclose(single_reference_trace, expected_trace, atol=0.01) + local_trace = recording_local_cmr.get_traces(channel_ids=all_channel_ids) + local_trace_sub = recording_local_cmr.get_traces(channel_ids=channel_ids) + assert np.all(local_trace[:, indices] == local_trace_sub) + local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids) local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids) - assert np.all(local_trace[:, indices] == local_trace_sub) # test segment slicing @@ -157,8 +171,23 @@ def test_common_reference_groups(recording): assert np.allclose(traces[:, 1], 0) +def test_min_local_radius(): + # Test that local radius smaller than the number of channels is handled correctly + recording = generate_recording(durations=[1.0], num_channels=32) + # remove closest channel to first channel + recording = recording.remove_channels(recording.channel_ids[1:5]) + with pytest.warns(UserWarning): + recording_local_car = common_reference( + recording, reference="local", local_radius=(60, 150), operator="average", min_local_neighbors=5 + ) + + # traces = recording.get_traces() + # assert np.allclose(traces[:, 0], recording_local_cmr.get_traces()[:, 0] + np.median(traces[:, [1]], axis=1), atol=0.01) + # assert np.allclose(traces[:, 1], recording_local_cmr.get_traces()[:, 1] + np.median(traces[:, [0, 2]], axis=1), atol=0.01) + + # assert np.allclose(traces[:, 0], recording_local_car.get_traces()[:, 0] + np.mean(traces[:, [1]], axis=1), atol=0.01) + # assert np.allclose(traces[:, 1], recording_local_car.get_traces()[:, 1] + + + if __name__ == "__main__": - recording = _generate_test_recording() - test_common_reference(recording) - test_common_reference_channel_slicing(recording) - test_common_reference_groups(recording) + test_min_local_radius() From e88dfeae43778be240aec5e6bfea137ecb1db1c8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 17 Mar 2026 17:21:33 +0100 Subject: [PATCH 07/10] fix: transpose local kernel --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index fc848d4e48..dcc30dc804 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -235,7 +235,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift else: # then it must be local average, use local_kernel re_referenced_traces = ( - traces[:, channel_indices] - traces.dot(self.local_kernel)[:, channel_indices] + traces[:, channel_indices] - traces.dot(self.local_kernel.T)[:, channel_indices] ) return re_referenced_traces.astype(self.dtype, copy=False) From c39d0ae70e107a9c4d72c0836a6650e47a2b49f1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 18 Mar 2026 15:20:22 +0100 Subject: [PATCH 08/10] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/common_reference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index df2f5d4706..9c6240a182 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -218,7 +218,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): shift = self.operator_func(traces, axis=1, keepdims=True) else: shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True) - re_referenced_traces = traces[:, channel_indices] - shift # shift[:, channel_indices] + re_referenced_traces = traces[:, channel_indices] - shift elif self.reference == "single": # single channel -> no need of operator shift = traces[:, self.ref_channel_indices] From 60a6f7d759c6c177a90d3da6f081f7535a813f5f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 19 Mar 2026 10:22:42 +0100 Subject: [PATCH 09/10] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/tests/test_common_reference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index aa56b081fe..3fbc260b5f 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -43,7 +43,6 @@ def test_common_reference(recording): assert np.allclose(traces[:, 0], rec_local_cmr.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01) assert np.allclose(traces[:, 1], rec_local_cmr.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01) - # TODO: fix this!!! assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.mean(traces[:, [2, 3]], axis=1), atol=0.01) assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.mean(traces[:, [3]], axis=1), atol=0.01) From 59b26def857a3b48f46232b5e31937cfee88b4aa Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 19 Mar 2026 10:46:09 +0100 Subject: [PATCH 10/10] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/common_reference.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 9c6240a182..b1469a0250 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -117,12 +117,6 @@ def __init__( [ch in recording.channel_ids for ch in ref_channel_ids] ), "Some 'ref_channel_ids' are wrong!" elif reference == "local": - if operator == "median": - warnings.warn( - "Using a local median reference can be very computationally intensive. Consider using a local " - "average reference instead or pre-computing the local median reference and using the 'single' " - "reference option." - ) assert groups is None, "With 'local' CAR, the group option should not be used." closest_inds, dist = get_closest_channels(recording) # The neighbor kernel is a matrix that will be used to calculate the local reference.