Skip to content

Commit f632aa7

Browse files
alejoe91zzhmarkpre-commit-ci[bot]
authored
Switch to kernel for local CAR (#4454)
Co-authored-by: ZZUOHAN <zzuohan1@jh.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 73bd245 commit f632aa7

File tree

2 files changed

+102
-32
lines changed

2 files changed

+102
-32
lines changed

src/spikeinterface/preprocessing/common_reference.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import numpy as np
1+
import warnings
22
from typing import Literal
3+
import numpy as np
34

45
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
56

@@ -64,7 +65,9 @@ class CommonReferenceRecording(BasePreprocessor):
6465
annulus. The exclude radius is used to exclude channels that are too close to the reference channel and the
6566
include radius delineates the outer boundary of the annulus whose role is to exclude channels
6667
that are too far away.
67-
68+
min_local_neighbors : int, default: 5
69+
Use in the local CAR implementation to set a minimum number of neighbors. If the number of neighbors within the
70+
annulus is less than this number, then the closest neighbors are used until this number is reached.
6871
dtype : None or dtype, default: None
6972
If None the parent dtype is kept.
7073
@@ -87,10 +90,10 @@ def __init__(
8790
dtype: str | np.dtype | None = None,
8891
):
8992
num_chans = recording.get_num_channels()
90-
neighbors = None
93+
local_kernel = None
9194
# some checks
9295
if reference not in ("global", "single", "local"):
93-
raise ValueError("'reference' must be either 'global', 'single', 'local'")
96+
raise ValueError("'reference' must be either 'global', 'single', or 'local'")
9497
if operator not in ("median", "average"):
9598
raise ValueError("'operator' must be either 'median', 'average'")
9699

@@ -116,12 +119,28 @@ def __init__(
116119
elif reference == "local":
117120
assert groups is None, "With 'local' CAR, the group option should not be used."
118121
closest_inds, dist = get_closest_channels(recording)
119-
neighbors = {}
122+
# The neighbor kernel is a matrix that will be used to calculate the local reference.
123+
# It has shape (num_chans, num_chans) and is filled with zeros except for the columns corresponding to the
124+
# neighbors of each channel, which are filled with 1 / number of neighbors. This way, when we do a dot
125+
# product between the traces and the neighbor kernel, we get the local average reference for each channel.
126+
# For the median operator, the neighbors are extracted from the kernel on-the-fly via nonzero.
127+
local_kernel = np.zeros((num_chans, num_chans))
128+
not_enough_channels = []
120129
for i in range(num_chans):
121-
mask = dist[i, :] > local_radius[0]
122-
nn = np.cumsum(mask)
123-
mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors))
124-
neighbors[i] = closest_inds[i, mask]
130+
annulus_mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1])
131+
if np.sum(annulus_mask) >= min_local_neighbors:
132+
neighbors_i = closest_inds[i, annulus_mask]
133+
else:
134+
# Not enough channels in the annulus — take the closest ones beyond the inner radius
135+
not_enough_channels.append(recording.channel_ids[i])
136+
beyond_inner = dist[i, :] > local_radius[0]
137+
neighbors_i = closest_inds[i, beyond_inner][:min_local_neighbors]
138+
local_kernel[i, neighbors_i] = 1 / len(neighbors_i)
139+
if len(not_enough_channels) > 0:
140+
warnings.warn(
141+
f"The following channels did not have enough neighbors in the annulus and used the closest "
142+
f"{min_local_neighbors} channels beyond the inner radius instead: {', '.join(not_enough_channels)}"
143+
)
125144
dtype_ = fix_dtype(recording, dtype)
126145
BasePreprocessor.__init__(self, recording, dtype=dtype_)
127146

@@ -142,8 +161,7 @@ def __init__(
142161
operator,
143162
group_indices,
144163
ref_channel_indices,
145-
local_radius,
146-
neighbors,
164+
local_kernel,
147165
dtype_,
148166
)
149167
self.add_recording_segment(rec_segment)
@@ -168,8 +186,7 @@ def __init__(
168186
operator,
169187
group_indices,
170188
ref_channel_indices,
171-
local_radius,
172-
neighbors,
189+
local_kernel,
173190
dtype,
174191
):
175192
BasePreprocessorSegment.__init__(self, parent_recording_segment)
@@ -178,11 +195,11 @@ def __init__(
178195
self.operator = operator
179196
self.group_indices = group_indices
180197
self.ref_channel_indices = ref_channel_indices
181-
self.local_radius = local_radius
182-
self.neighbors = neighbors
198+
self.local_kernel = local_kernel
183199
self.temp = None
184200
self.dtype = dtype
185-
self.operator_func = operator = np.mean if self.operator == "average" else np.median
201+
self.operator = operator
202+
self.operator_func = np.mean if self.operator == "average" else np.median
186203

187204
def get_traces(self, start_frame, end_frame, channel_indices):
188205
# Let's do the case with group_indices equal None as that is easy
@@ -201,13 +218,17 @@ def get_traces(self, start_frame, end_frame, channel_indices):
201218
shift = traces[:, self.ref_channel_indices]
202219
re_referenced_traces = traces[:, channel_indices] - shift
203220
else: # then it must be local
204-
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
205-
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
206-
for i, channel_index in enumerate(channel_indices_array):
207-
channel_neighborhood = self.neighbors[channel_index]
208-
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
209-
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift
210-
221+
if self.operator == "median":
222+
channel_indices_array = np.arange(traces.shape[1])[channel_indices]
223+
re_referenced_traces = np.zeros((traces.shape[0], len(channel_indices_array)), dtype="float32")
224+
for i, channel_index in enumerate(channel_indices_array):
225+
channel_neighborhood = np.nonzero(self.local_kernel[channel_index])[0]
226+
channel_shift = self.operator_func(traces[:, channel_neighborhood], axis=1)
227+
re_referenced_traces[:, i] = traces[:, channel_index] - channel_shift
228+
else: # then it must be local average, use local_kernel
229+
re_referenced_traces = (
230+
traces[:, channel_indices] - traces.dot(self.local_kernel.T)[:, channel_indices]
231+
)
211232
return re_referenced_traces.astype(self.dtype, copy=False)
212233

213234
# Then the old implementation for backwards compatibility that supports grouping

src/spikeinterface/preprocessing/tests/test_common_reference.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ def test_common_reference(recording):
2626
rec_cmr_ref = common_reference(recording, reference="global", operator="median", ref_channel_ids=["a", "b", "c"])
2727
rec_car = common_reference(recording, reference="global", operator="average")
2828
rec_sin = common_reference(recording, reference="single", ref_channel_ids=["a"])
29-
rec_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
29+
rec_local_cmr = common_reference(
30+
recording, reference="local", local_radius=(25, 65), operator="median", min_local_neighbors=1
31+
)
32+
rec_local_car = common_reference(
33+
recording, reference="local", local_radius=(25, 65), operator="average", min_local_neighbors=1
34+
)
3035

3136
traces = recording.get_traces()
3237
assert np.allclose(traces, rec_cmr.get_traces() + np.median(traces, axis=1, keepdims=True), atol=0.01)
@@ -35,21 +40,26 @@ def test_common_reference(recording):
3540
assert not np.all(rec_sin.get_traces()[0])
3641
assert np.allclose(rec_sin.get_traces()[:, 1], traces[:, 1] - traces[:, 0])
3742

38-
assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01)
39-
assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01)
43+
assert np.allclose(traces[:, 0], rec_local_cmr.get_traces()[:, 0] + np.median(traces[:, [2, 3]], axis=1), atol=0.01)
44+
assert np.allclose(traces[:, 1], rec_local_cmr.get_traces()[:, 1] + np.median(traces[:, [3]], axis=1), atol=0.01)
45+
46+
assert np.allclose(traces[:, 0], rec_local_car.get_traces()[:, 0] + np.mean(traces[:, [2, 3]], axis=1), atol=0.01)
47+
assert np.allclose(traces[:, 1], rec_local_car.get_traces()[:, 1] + np.mean(traces[:, [3]], axis=1), atol=0.01)
4048

4149
# Saving tests
4250
rec_cmr.save(verbose=False)
4351
rec_car.save(verbose=False)
4452
rec_sin.save(verbose=False)
53+
rec_local_cmr.save(verbose=False)
4554
rec_local_car.save(verbose=False)
4655

4756

4857
def test_common_reference_channel_slicing(recording):
4958
recording_cmr = common_reference(recording, reference="global", operator="median")
5059
recording_car = common_reference(recording, reference="global", operator="average")
5160
recording_single_reference = common_reference(recording, reference="single", ref_channel_ids=["b"])
52-
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
61+
recording_local_cmr = common_reference(recording, reference="local", local_radius=(20, 65), operator="median")
62+
recording_local_car = common_reference(recording, reference="local", local_radius=(20, 65), operator="average")
5363

5464
channel_ids = ["b", "d"]
5565
indices = recording.ids_to_indices(channel_ids)
@@ -73,9 +83,12 @@ def test_common_reference_channel_slicing(recording):
7383

7484
assert np.allclose(single_reference_trace, expected_trace, atol=0.01)
7585

86+
local_trace = recording_local_cmr.get_traces(channel_ids=all_channel_ids)
87+
local_trace_sub = recording_local_cmr.get_traces(channel_ids=channel_ids)
88+
assert np.all(local_trace[:, indices] == local_trace_sub)
89+
7690
local_trace = recording_local_car.get_traces(channel_ids=all_channel_ids)
7791
local_trace_sub = recording_local_car.get_traces(channel_ids=channel_ids)
78-
7992
assert np.all(local_trace[:, indices] == local_trace_sub)
8093

8194
# test segment slicing
@@ -157,8 +170,44 @@ def test_common_reference_groups(recording):
157170
assert np.allclose(traces[:, 1], 0)
158171

159172

173+
def test_min_local_radius():
174+
# Test that local radius smaller than the number of channels is handled correctly
175+
recording = generate_recording(durations=[1.0], num_channels=32)
176+
# remove closest channel to first channel
177+
recording = recording.remove_channels(recording.channel_ids[1:5])
178+
with pytest.warns(UserWarning):
179+
recording_local_car = common_reference(
180+
recording, reference="local", local_radius=(60, 150), operator="average", min_local_neighbors=5
181+
)
182+
183+
184+
@pytest.mark.skip(reason="This test can be used to check local CAR vs local CMR performance")
185+
def test_local_car_vs_cmr_performance():
186+
import time
187+
188+
# Test that local CAR is faster than local CMR when there are many channels
189+
recording = generate_recording(durations=[10.0], num_channels=384)
190+
191+
rec_local_cmr = common_reference(
192+
recording, reference="local", local_radius=(20, 65), operator="median", min_local_neighbors=1
193+
)
194+
t_start_cmr = time.perf_counter()
195+
rec_local_cmr.get_traces()
196+
t_end_cmr = time.perf_counter()
197+
cmr_time = t_end_cmr - t_start_cmr
198+
199+
rec_local_car = common_reference(
200+
recording, reference="local", local_radius=(20, 65), operator="average", min_local_neighbors=1
201+
)
202+
t_start_car = time.perf_counter()
203+
rec_local_car.get_traces()
204+
t_end_car = time.perf_counter()
205+
car_time = t_end_car - t_start_car
206+
207+
print(f"Local CMR time: {cmr_time:.4f} seconds")
208+
print(f"Local CAR time: {car_time:.4f} seconds")
209+
assert car_time < cmr_time
210+
211+
160212
if __name__ == "__main__":
161-
recording = _generate_test_recording()
162-
test_common_reference(recording)
163-
test_common_reference_channel_slicing(recording)
164-
test_common_reference_groups(recording)
213+
test_local_car_vs_cmr_performance()

0 commit comments

Comments
 (0)