Skip to content

Commit f6e45ce

Browse files
zzhmarkpre-commit-ci[bot]alejoe91
authored
Implementing KNN referencing and a backup referencing mechanism for common_reference (#4412)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent bd1f1ca commit f6e45ce

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/spikeinterface/preprocessing/common_reference.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,14 @@ def __init__(
8383
groups: list | None = None,
8484
ref_channel_ids: list | str | int | None = None,
8585
local_radius: tuple[float, float] = (30.0, 55.0),
86+
min_local_neighbors: int = 5,
8687
dtype: str | np.dtype | None = None,
8788
):
8889
num_chans = recording.get_num_channels()
8990
neighbors = None
9091
# some checks
9192
if reference not in ("global", "single", "local"):
92-
raise ValueError("'reference' must be either 'global', 'single' or 'local'")
93+
raise ValueError("'reference' must be either 'global', 'single', 'local'")
9394
if operator not in ("median", "average"):
9495
raise ValueError("'operator' must be either 'median', 'average'")
9596

@@ -117,10 +118,10 @@ def __init__(
117118
closest_inds, dist = get_closest_channels(recording)
118119
neighbors = {}
119120
for i in range(num_chans):
120-
mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1])
121+
mask = dist[i, :] > local_radius[0]
122+
nn = np.cumsum(mask)
123+
mask &= (dist[i, :] <= local_radius[1]) | ((0 < nn) & (nn <= min_local_neighbors))
121124
neighbors[i] = closest_inds[i, mask]
122-
assert len(neighbors[i]) > 0, "No reference channels available in the local annulus for selection."
123-
124125
dtype_ = fix_dtype(recording, dtype)
125126
BasePreprocessor.__init__(self, recording, dtype=dtype_)
126127

@@ -136,7 +137,14 @@ def __init__(
136137

137138
for parent_segment in recording._recording_segments:
138139
rec_segment = CommonReferenceRecordingSegment(
139-
parent_segment, reference, operator, group_indices, ref_channel_indices, local_radius, neighbors, dtype_
140+
parent_segment,
141+
reference,
142+
operator,
143+
group_indices,
144+
ref_channel_indices,
145+
local_radius,
146+
neighbors,
147+
dtype_,
140148
)
141149
self.add_recording_segment(rec_segment)
142150

@@ -147,6 +155,7 @@ def __init__(
147155
operator=operator,
148156
ref_channel_ids=ref_channel_ids,
149157
local_radius=local_radius,
158+
min_local_neighbors=min_local_neighbors,
150159
dtype=dtype_.str,
151160
)
152161

0 commit comments

Comments
 (0)