Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
db8849e
add `remove_bad_channels` and `RemoveBadChannelsRecording`
chrishalcrow Feb 12, 2025
5138e1f
add test and propogate the bad_channel_ids
chrishalcrow Feb 12, 2025
ef0cc53
Respond to Joe
chrishalcrow Feb 12, 2025
719463a
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Mar 19, 2025
a5e8af5
fix bug
chrishalcrow Mar 19, 2025
10c5d41
add `detect_and_interpolate_bad_channels`
chrishalcrow Mar 19, 2025
2a199e9
update `preprocessinglist`
chrishalcrow Mar 19, 2025
6c2fdee
refactor kwarg update
chrishalcrow Mar 20, 2025
3320d00
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Mar 20, 2025
6b07bb4
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 1, 2025
4270885
add _precomputable_kwarg_names
chrishalcrow Apr 1, 2025
658a015
recording -> parent_recording
chrishalcrow Apr 3, 2025
437669f
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 3, 2025
4036938
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 8, 2025
badad10
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 8, 2025
baf887c
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 30, 2025
1f1d8c8
remove _main_ids from detect_and_interpolate
chrishalcrow Apr 30, 2025
166d722
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 30, 2025
291993f
channel_labels for prov
chrishalcrow May 28, 2025
a959b35
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow May 28, 2025
56da71c
add some docs
chrishalcrow May 28, 2025
02540f4
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Jun 2, 2025
bbe2a53
Fix conflicts
alejoe91 Jun 12, 2025
a2a6f18
Update src/spikeinterface/preprocessing/detect_bad_channels.py
chrishalcrow Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/how_to/process_by_channel_group.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ to any preprocessing function.
shifted_recordings = spre.phase_shift(split_recording_dict)
filtered_recording = spre.bandpass_filter(shifted_recording)
referenced_recording = spre.common_reference(filtered_recording)
good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording)

We can then aggregate the recordings back together using the ``aggregate_channels`` function

.. code-block:: python

combined_preprocessed_recording = aggregate_channels(referenced_recording)
combined_preprocessed_recording = aggregate_channels(good_channels_recording)

Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever
calling its :py:func:`~get_traces` method, the data will have been
Expand Down
12 changes: 12 additions & 0 deletions doc/modules/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,18 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe
# Case 2 : interpolate then
rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids)

Once you have tested these functions and decided on your workflow, you can use the `detect_and_*`
functions to do everything at once. These return a Preprocessor class, so are consistent with
the "chain" concept for this module. For example:

.. code-block:: python

# detect and remove bad channels
rec_only_good_channels = detect_and_remove_bad_channels(recording=rec)

# detect and interpolate the bad channels
rec_interpolated_channels = detect_and_interpolate_bad_channels(recording=rec)


* :py:func:`~spikeinterface.preprocessing.detect_bad_channels()`
* :py:func:`~spikeinterface.preprocessing.interpolate_bad_channels()`
Expand Down
207 changes: 144 additions & 63 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,148 @@
import numpy as np
from typing import Literal

from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .filter import highpass_filter
from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording
from spikeinterface.core.channelslice import ChannelSliceRecording

from inspect import signature

_bad_channel_detection_kwargs_doc = """Different methods are implemented:

* std : threhshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels"
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks
"""


class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording):
"""
Detects and removes bad channels. If `bad_channel_ids` are given,
the detection is skipped and uses these instead.

{}
bad_channel_ids : np.array | list | None, default: None
If given, these are used rather than being detected.
channel_labels : np.array | list | None, default: None
If given, these are labels given to the channels by the
detection process. Only intended for use when loading.

Returns
-------
removed_bad_channels_recording : DetectAndRemoveBadChannelsRecording
The recording with bad channels removed
"""

_precomputable_kwarg_names = ["bad_channel_ids"]
Comment thread
chrishalcrow marked this conversation as resolved.
Outdated

def __init__(
self,
parent_recording: BaseRecording,
bad_channel_ids=None,
channel_labels=None,
**detect_bad_channels_kwargs,
):

if bad_channel_ids is None:
bad_channel_ids, channel_labels = detect_bad_channels(
recording=parent_recording, **detect_bad_channels_kwargs
)
else:
channel_labels = None

self._main_ids = parent_recording.get_channel_ids()
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)]

ChannelSliceRecording.__init__(
self,
parent_recording=parent_recording,
channel_ids=new_channel_ids,
)

self._kwargs.update({"bad_channel_ids": bad_channel_ids})
if channel_labels is not None:
self._kwargs.update({"channel_labels": channel_labels})

all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
self._kwargs.update(all_bad_channels_kwargs)


detect_and_remove_bad_channels = define_function_handling_dict_from_class(
source_class=DetectAndRemoveBadChannelsRecording, name="detect_and_remove_bad_channels"
)
DetectAndRemoveBadChannelsRecording.__doc__ = DetectAndRemoveBadChannelsRecording.__doc__.format(
_bad_channel_detection_kwargs_doc
)


def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs):
"""Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters."""

sig = signature(detect_bad_channels)
all_detect_bad_channels_kwargs = {
k: v.default for k, v in sig.parameters.items() if k not in ["recording", "parent_recording"]
}
all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs)
return all_detect_bad_channels_kwargs


def detect_bad_channels(
Expand All @@ -32,69 +172,7 @@ def detect_bad_channels(
Perform bad channel detection.
The recording is assumed to be filtered. If not, a highpass filter is applied on the fly.

Different methods are implemented:

* std : threhshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels"
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks
{}
Comment thread
chrishalcrow marked this conversation as resolved.

Returns
-------
Expand Down Expand Up @@ -269,6 +347,9 @@ def detect_bad_channels(
return bad_channel_ids, channel_labels


detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(_bad_channel_detection_kwargs_doc)


# ----------------------------------------------------------------------------------------------
# IBL Detect Bad Channels
# ----------------------------------------------------------------------------------------------
Expand Down
61 changes: 60 additions & 1 deletion src/spikeinterface/preprocessing/interpolate_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from spikeinterface.preprocessing import preprocessing_tools
from .detect_bad_channels import (
_bad_channel_detection_kwargs_doc,
detect_bad_channels,
_get_all_detect_bad_channel_kwargs,
)
from inspect import signature


class InterpolateBadChannelsRecording(BasePreprocessor):
Expand Down Expand Up @@ -82,6 +88,59 @@ def check_inputs(self, recording, bad_channel_ids):
raise NotImplementedError("Channel spacing units must be um")


class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording):
"""
Detects and interpolates bad channels. If `bad_channel_ids` are given,
the detection is skipped and uses these instead.

{}
bad_channel_ids : np.array | list | None, default: None
If given, these are used rather than being detected.
channel_labels : np.array | list | None, default: None
If given, these are labels given to the channels by the
detection process. Only intended for use when loading.

Returns
-------
interpolated_bad_channels_recording : DetectAndInterpolateBadChannelsRecording
The recording with bad channels removed
"""

_precomputable_kwarg_names = ["bad_channel_ids"]

def __init__(
self,
recording: BaseRecording,
bad_channel_ids=None,
**detect_bad_channels_kwargs,
):
if bad_channel_ids is None:
bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs)
else:
channel_labels = None

InterpolateBadChannelsRecording.__init__(
self,
recording,
bad_channel_ids=bad_channel_ids,
)

self._kwargs.update({"bad_channel_ids": bad_channel_ids})
if channel_labels is not None:
self._kwargs.update({"channel_labels": channel_labels})

all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
self._kwargs.update(all_bad_channels_kwargs)


detect_and_interpolate_bad_channels = define_function_handling_dict_from_class(
source_class=DetectAndInterpolateBadChannelsRecording, name="detect_and_interpolate_bad_channels"
)
DetectAndInterpolateBadChannelsRecording.__doc__ = DetectAndInterpolateBadChannelsRecording.__doc__.format(
_bad_channel_detection_kwargs_doc
)


class InterpolateBadChannelsSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_indices, weights):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
Expand Down
11 changes: 10 additions & 1 deletion src/spikeinterface/preprocessing/preprocessing_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@
from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad
from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation
from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter
from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels
from .interpolate_bad_channels import (
DetectAndInterpolateBadChannelsRecording,
detect_and_interpolate_bad_channels,
InterpolateBadChannelsRecording,
interpolate_bad_channels,
)
from .detect_bad_channels import DetectAndRemoveBadChannelsRecording, detect_and_remove_bad_channels
from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction
from .directional_derivative import DirectionalDerivativeRecording, directional_derivative
from .depth_order import DepthOrderRecording, depth_order
Expand All @@ -63,6 +69,9 @@
# re-reference
CommonReferenceRecording: common_reference,
PhaseShiftRecording: phase_shift,
# bad channel detection/interpolation
DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels,
DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels,
# misc
RectifyRecording: rectify,
ClipRecording: clip,
Expand Down
Loading
Loading