-
Notifications
You must be signed in to change notification settings - Fork 259
Expand file tree
/
Copy pathcommon_reference.py
More file actions
354 lines (314 loc) · 17.2 KB
/
common_reference.py
File metadata and controls
354 lines (314 loc) · 17.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import threading
import warnings
import weakref
from typing import Literal
import numpy as np
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.core import get_closest_channels
from spikeinterface.core.baserecording import BaseRecording
from .filter import fix_dtype
class CommonReferenceRecording(BasePreprocessor):
    """
    Re-references the recording extractor traces. That is, the values of the traces are
    shifted so that there is a new zero (reference).

    The new reference can be estimated either by using a common median reference (CMR) or
    a common average reference (CAR).

    The new reference can be set three ways:
        * "global": the median/average of all channels is set as the new reference.
            In this case, the 'global' median/average is subtracted from all channels.
            A collection of channels can also be used as the reference: pass them as a list in
            `ref_channel_ids` and the median/average of those channels is subtracted from all channels.
        * "single": a single channel from the recording is set as the new reference.
            This channel is subtracted from all channels. To use this option, pass a single channel id
            (or a list with exactly one channel id) in `ref_channel_ids`. Note that this option will zero
            out the reference channel. When `groups` is given, one reference channel per group is expected.
        * "local": the median/average within an annulus is set as the new reference.
            The parameters of the annulus are specified using the `local_radius` argument. With this option, both
            channels which are too close and channels which are too far are excluded from the median/average. Note
            that setting the `local_radius` to (0, exclude_radius) will yield a simple circular local region.

    Parameters
    ----------
    recording : RecordingExtractor
        The recording extractor to be re-referenced
    reference : "global" | "single" | "local", default: "global"
        If "global" the reference is the average or median across all the channels. To select a subset of channels,
        you can use the `ref_channel_ids` parameter.
        If "single", the reference is a single channel that needs to be set with `ref_channel_ids`
        (one channel per group when `groups` is given).
        If "local", the reference is the set of channels within an annulus that must be set with the `local_radius` parameter.
    operator : "median" | "average", default: "median"
        If "median", a common median reference (CMR) is implemented (the median of
        the selected channels is removed for each timestamp).
        If "average", common average reference (CAR) is implemented (the mean of the selected channels is removed
        for each timestamp).
    groups : list or None, default: None
        List of lists containing the channel ids for splitting the reference. The CMR, CAR, or referencing with respect to
        single channels are applied group-wise. Groups cannot be combined with the "local" reference.
        It is useful when dealing with different channel groups, e.g. multiple tetrodes.
        With the "global" reference, each group is referenced to the median/average of all of its own channels
        (`ref_channel_ids` is not used in that case). Channels that belong to no group are returned as zeros.
    ref_channel_ids : list | str | int | None, default: None
        If "global" reference, a list of channels to be used as reference.
        If "single" reference, a list of one channel or a single channel id is expected.
        If "groups" is provided with the "single" reference, a list with one reference channel per group is expected.
    local_radius : tuple(float, float), default: (30.0, 55.0)
        Use in the local CAR implementation as the selecting annulus with the following format:
        `(exclude radius, include radius)`
        Where the exclude radius is the inner radius of the annulus and the include radius is the outer radius of the
        annulus. A channel at distance d is part of the annulus when exclude_radius < d <= include_radius: the
        exclude radius removes channels that are too close to the channel being referenced and the include
        radius removes channels that are too far away.
    min_local_neighbors : int, default: 5
        Use in the local CAR implementation to set a minimum number of neighbors. If the number of neighbors within the
        annulus is less than this number, the closest `min_local_neighbors` channels beyond the exclude radius are
        used instead (a warning lists the affected channels).
    dtype : None or dtype, default: None
        If None the parent dtype is kept.
    n_workers : int, default: 1
        Number of threads used to compute the global median/average. Only used when `reference="global"` with
        neither `ref_channel_ids` nor `groups`; the work is split across time blocks of at least 8192 frames.

    Returns
    -------
    referenced_recording : CommonReferenceRecording
        The re-referenced recording extractor object

    Raises
    ------
    ValueError
        If `reference` or `operator` is invalid, if a "global" `ref_channel_ids` is not a list, or if a channel
        has no neighbor at all beyond the exclude radius with the "local" reference.
    """

    def __init__(
        self,
        recording: BaseRecording,
        reference: Literal["global", "single", "local"] = "global",
        operator: Literal["median", "average"] = "median",
        groups: list | None = None,
        ref_channel_ids: list | str | int | None = None,
        local_radius: tuple[float, float] = (30.0, 55.0),
        min_local_neighbors: int = 5,
        dtype: str | np.dtype | None = None,
        n_workers: int = 1,
    ):
        num_chans = recording.get_num_channels()
        local_kernel = None

        # some checks
        if reference not in ("global", "single", "local"):
            raise ValueError("'reference' must be either 'global', 'single', or 'local'")
        if operator not in ("median", "average"):
            raise ValueError("'operator' must be either 'median', 'average'")

        if reference == "global":
            if ref_channel_ids is not None:
                if not isinstance(ref_channel_ids, list):
                    raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list")
        elif reference == "single":
            assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'"
            if groups is not None:
                assert len(ref_channel_ids) == len(groups), "'ref_channel_ids' and 'groups' must have the same length"
            else:
                if np.isscalar(ref_channel_ids):
                    ref_channel_ids = [ref_channel_ids]
                else:
                    assert (
                        len(ref_channel_ids) == 1
                    ), "'ref_channel_ids' with no 'groups' must be a single channel id or a list of one element"
                ref_channel_ids = np.asarray(ref_channel_ids)
                assert np.all(
                    [ch in recording.channel_ids for ch in ref_channel_ids]
                ), "Some 'ref_channel_ids' are wrong!"
        elif reference == "local":
            assert groups is None, "With 'local' CAR, the group option should not be used."
            closest_inds, dist = get_closest_channels(recording)

            # The neighbor kernel is a matrix that will be used to calculate the local reference.
            # It has shape (num_chans, num_chans) and is filled with zeros except for the columns corresponding to the
            # neighbors of each channel, which are filled with 1 / number of neighbors. This way, when we do a dot
            # product between the traces and the neighbor kernel, we get the local average reference for each channel.
            # For the median operator, the neighbors are extracted from the kernel on-the-fly via nonzero.
            local_kernel = np.zeros((num_chans, num_chans))
            not_enough_channels = []
            for i in range(num_chans):
                annulus_mask = (dist[i, :] > local_radius[0]) & (dist[i, :] <= local_radius[1])
                if np.sum(annulus_mask) >= min_local_neighbors:
                    neighbors_i = closest_inds[i, annulus_mask]
                else:
                    # Not enough channels in the annulus — take the closest ones beyond the inner radius
                    not_enough_channels.append(recording.channel_ids[i])
                    beyond_inner = dist[i, :] > local_radius[0]
                    neighbors_i = closest_inds[i, beyond_inner][:min_local_neighbors]
                if len(neighbors_i) == 0:
                    # Without this check the weight below would be a bare ZeroDivisionError.
                    raise ValueError(
                        f"Channel {recording.channel_ids[i]} has no neighbors beyond the exclude radius "
                        f"{local_radius[0]}: cannot compute a local reference. Decrease the exclude radius."
                    )
                local_kernel[i, neighbors_i] = 1 / len(neighbors_i)

            if len(not_enough_channels) > 0:
                # Channel ids may be ints (or numpy scalars): stringify before joining.
                not_enough_str = ", ".join(str(ch) for ch in not_enough_channels)
                warnings.warn(
                    f"The following channels did not have enough neighbors in the annulus and used the closest "
                    f"{min_local_neighbors} channels beyond the inner radius instead: {not_enough_str}"
                )

        dtype_ = fix_dtype(recording, dtype)
        BasePreprocessor.__init__(self, recording, dtype=dtype_)

        # transforms groups (ids) to groups (indices)
        if groups is not None:
            group_indices = [self.ids_to_indices(g) for g in groups]
        else:
            group_indices = None
        if ref_channel_ids is not None:
            ref_channel_indices = self.ids_to_indices(ref_channel_ids)
        else:
            ref_channel_indices = None

        assert int(n_workers) >= 1, "n_workers must be >= 1"

        for parent_segment in recording.segments:
            rec_segment = CommonReferenceRecordingSegment(
                parent_segment,
                reference,
                operator,
                group_indices,
                ref_channel_indices,
                local_kernel,
                dtype_,
                n_workers=int(n_workers),
            )
            self.add_recording_segment(rec_segment)

        self._kwargs = dict(
            recording=recording,
            reference=reference,
            groups=groups,
            operator=operator,
            ref_channel_ids=ref_channel_ids,
            local_radius=local_radius,
            min_local_neighbors=min_local_neighbors,
            dtype=dtype_.str,
            n_workers=int(n_workers),
        )
class CommonReferenceRecordingSegment(BasePreprocessorSegment):
    """
    Recording segment that re-references the traces of its parent segment on the fly.

    Parameters
    ----------
    parent_recording_segment : BaseRecordingSegment
        The segment whose traces are re-referenced.
    reference : "global" | "single" | "local"
        Reference mode (validated by CommonReferenceRecording).
    operator : "median" | "average"
        Selects ``np.median`` or ``np.mean`` as the reduction.
    group_indices : list of array-like | None
        Channel indices (not ids) of each group, or None when no grouping is used.
    ref_channel_indices : array-like | None
        Channel indices (not ids) of the reference channel(s).
    local_kernel : np.ndarray | None
        (num_chans, num_chans) neighbor weighting matrix used by the "local" reference.
    dtype : np.dtype
        Dtype of the traces returned by ``get_traces``.
    n_workers : int, default: 1
        Number of threads used for the global median/average across time blocks.
    """

    def __init__(
        self,
        parent_recording_segment,
        reference,
        operator,
        group_indices,
        ref_channel_indices,
        local_kernel,
        dtype,
        n_workers=1,
    ):
        BasePreprocessorSegment.__init__(self, parent_recording_segment)
        self.reference = reference
        self.operator = operator
        self.group_indices = group_indices
        self.ref_channel_indices = ref_channel_indices
        self.local_kernel = local_kernel
        self.temp = None
        self.dtype = dtype
        self.operator_func = np.mean if self.operator == "average" else np.median
        self.n_workers = int(n_workers)
        # Per-caller-thread lazy pool map. See filter.FilterRecordingSegment
        # for full rationale and WeakKeyDictionary mechanics.
        self._cmr_pools = weakref.WeakKeyDictionary()
        self._cmr_pools_lock = threading.Lock()

    def _get_pool(self):
        """Lazy per-caller-thread thread pool for parallel median/mean across time blocks.

        Returns None when ``n_workers <= 1``. Each calling thread gets its own pool, which is
        shut down (without waiting) once that thread object is garbage collected.
        """
        if self.n_workers <= 1:
            return None
        thread = threading.current_thread()
        pool = self._cmr_pools.get(thread)
        if pool is None:
            with self._cmr_pools_lock:
                # Re-check under the lock so concurrent first calls create a single pool.
                pool = self._cmr_pools.get(thread)
                if pool is None:
                    from concurrent.futures import ThreadPoolExecutor

                    pool = ThreadPoolExecutor(max_workers=self.n_workers)
                    self._cmr_pools[thread] = pool
                    weakref.finalize(thread, pool.shutdown, wait=False)
        return pool

    def _parallel_reduce_axis1(self, traces):
        """Apply ``operator_func(..., axis=1)`` split across time blocks.

        numpy's partition-based median and BLAS-backed mean release the GIL
        during per-row work, so Python-thread parallelism delivers real
        speedup (measured ~10× on 16 threads for 1M × 384 median).
        """
        if self.n_workers == 1:
            return self.operator_func(traces, axis=1)
        T = traces.shape[0]
        # Minimum block size per worker: below this, per-thread overhead
        # outweighs the parallelism gain.
        min_block = 8192
        effective = max(1, min(self.n_workers, T // min_block))
        if effective == 1:
            return self.operator_func(traces, axis=1)
        pool = self._get_pool()
        block = (T + effective - 1) // effective
        bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)]

        def _work(t0, t1):
            return t0, t1, self.operator_func(traces[t0:t1, :], axis=1)

        futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds]
        results = [fut.result() for fut in futures]
        out_dtype = results[0][2].dtype
        out = np.empty(T, dtype=out_dtype)
        for t0, t1, block_out in results:
            out[t0:t1] = block_out
        return out

    def get_traces(self, start_frame, end_frame, channel_indices):
        """Return the re-referenced traces for ``channel_indices`` between the given frames."""
        # ``None`` means "all channels". Normalize it: ``traces[:, None]`` would insert a new axis
        # instead of selecting columns, and the subtraction would then broadcast to (T, T, C).
        if channel_indices is None:
            channel_indices = slice(None)

        # All channels are needed to compute the reference, whatever subset is requested.
        traces = self.parent_recording_segment.get_traces(start_frame, end_frame, slice(None))

        if self.group_indices is None:
            re_referenced_traces = self._reference_ungrouped(traces, channel_indices)
        else:
            re_referenced_traces = self._reference_grouped(traces, channel_indices)
        return re_referenced_traces.astype(self.dtype, copy=False)

    def _reference_ungrouped(self, traces, channel_indices):
        """Re-reference ``traces[:, channel_indices]`` when no channel groups are defined."""
        if self.reference == "global":
            if self.ref_channel_indices is None:
                # Hot path: parallelizable global median/mean across all channels.
                shift = self._parallel_reduce_axis1(traces)[:, np.newaxis]
            else:
                shift = self.operator_func(traces[:, self.ref_channel_indices], axis=1, keepdims=True)
            return traces[:, channel_indices] - shift

        if self.reference == "single":
            # single channel -> no need of operator
            shift = traces[:, self.ref_channel_indices]
            return traces[:, channel_indices] - shift

        # Otherwise the reference is "local".
        if self.operator == "median":
            channel_indices_array = np.arange(traces.shape[1])[channel_indices]
            # Write directly in the output dtype (the final astype is then a no-op); the previous
            # float32 scratch buffer silently dropped precision for float64 outputs.
            re_referenced = np.empty((traces.shape[0], len(channel_indices_array)), dtype=self.dtype)
            for out_col, channel_index in enumerate(channel_indices_array):
                neighborhood = np.nonzero(self.local_kernel[channel_index])[0]
                channel_shift = self.operator_func(traces[:, neighborhood], axis=1)
                re_referenced[:, out_col] = traces[:, channel_index] - channel_shift
            return re_referenced

        # Local average: each kernel row holds 1 / n_neighbors on its neighbor columns, so the dot
        # product is the local mean. Only the kernel rows of the requested channels are needed.
        return traces[:, channel_indices] - traces.dot(self.local_kernel[channel_indices].T)

    def _reference_grouped(self, traces, channel_indices):
        """Re-reference ``traces[:, channel_indices]`` group-wise (backwards-compatible grouped path)."""
        sliced_channel_indices = np.arange(traces.shape[1])[channel_indices]
        # Channels that belong to no group are left at zero.
        re_referenced = np.zeros((traces.shape[0], sliced_channel_indices.size))
        for group_index, selected_indices_in_group, all_group_indices in self.slice_groups(sliced_channel_indices):
            (out_indices,) = np.nonzero(np.isin(sliced_channel_indices, selected_indices_in_group))
            in_group_traces = traces[:, selected_indices_in_group]
            if self.reference == "global":
                shift = self.operator_func(traces[:, all_group_indices], axis=1, keepdims=True)
            else:
                # single (as local is not allowed for groups)
                shift = self.operator_func(
                    traces[:, [self.ref_channel_indices[group_index]]], axis=1, keepdims=True
                )
            re_referenced[:, out_indices] = in_group_traces - shift
        return re_referenced

    def slice_groups(self, channel_indices):
        """
        Slice the channel indices into groups. This is used to apply the common reference to groups of channels.

        Parameters
        ----------
        channel_indices : array-like
            The channel indices to be sliced

        Returns
        -------
        zip with:
            * group_index: The index of the group
            * selected_channels: The selected channel indices in the group
            * group_channels: The channels indices in the group
        """
        selected_channels = []
        group_channels = []
        group_indices = []
        assert self.group_indices is not None, "No groups to slice"
        for group_index, group_channel_indices in enumerate(self.group_indices):
            selected_indices = [ind for ind in channel_indices if ind in group_channel_indices]
            # if no channels are in a group, do not return the group
            if len(selected_indices) > 0:
                group_channels.append(group_channel_indices)
                selected_channels.append(selected_indices)
                group_indices.append(group_index)
        return zip(group_indices, selected_channels, group_channels)
# Functional entry point built from the class: ``common_reference(recording, ...)`` returns a
# CommonReferenceRecording (presumably the helper also accepts dicts of recordings, as its name
# suggests — confirm in core_tools).
common_reference = define_function_handling_dict_from_class(
    name="common_reference",
    source_class=CommonReferenceRecording,
)