Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
db8849e
add `remove_bad_channels` and `RemoveBadChannelsRecording`
chrishalcrow Feb 12, 2025
5138e1f
add test and propogate the bad_channel_ids
chrishalcrow Feb 12, 2025
ef0cc53
Respond to Joe
chrishalcrow Feb 12, 2025
719463a
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Mar 19, 2025
a5e8af5
fix bug
chrishalcrow Mar 19, 2025
10c5d41
add `detect_and_interpolate_bad_channels`
chrishalcrow Mar 19, 2025
2a199e9
update `preprocessinglist`
chrishalcrow Mar 19, 2025
6c2fdee
refactor kwarg update
chrishalcrow Mar 20, 2025
3320d00
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Mar 20, 2025
6b07bb4
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 1, 2025
4270885
add _precomputable_kwarg_names
chrishalcrow Apr 1, 2025
658a015
recording -> parent_recording
chrishalcrow Apr 3, 2025
437669f
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 3, 2025
4036938
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 8, 2025
badad10
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 8, 2025
baf887c
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 30, 2025
1f1d8c8
remove _main_ids from detect_and_interpolate
chrishalcrow Apr 30, 2025
166d722
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Apr 30, 2025
291993f
channel_labels for prov
chrishalcrow May 28, 2025
a959b35
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow May 28, 2025
56da71c
add some docs
chrishalcrow May 28, 2025
02540f4
Merge branch 'main' into add-RemoveBadChannel-class
chrishalcrow Jun 2, 2025
bbe2a53
Fix conflicts
alejoe91 Jun 12, 2025
a2a6f18
Update src/spikeinterface/preprocessing/detect_bad_channels.py
chrishalcrow Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 120 additions & 65 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,122 @@
import numpy as np
from typing import Literal

from .filter import highpass_filter
from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing.filter import highpass_filter
from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording
from spikeinterface.core.channelslice import ChannelSliceRecording

from inspect import signature


class RemoveBadChannelsRecording(ChannelSliceRecording):
"""
Removes bad channels.

{}

Returns
-------
removed_bad_channels_recording : RemoveBadChannelsRecording
The recording with bad channels removed
"""

params_doc = """Different methods are implemented:
Comment thread
chrishalcrow marked this conversation as resolved.
Outdated

* std : threhshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels"
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks

"""

def __init__(
self,
recording: BaseRecording,
**detect_bad_channels_kwargs,
):

# get the default parameters from `detect_bad_channels`, and update with any user-specified parameters.
sig = signature(detect_bad_channels)
updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"}
updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs)

bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs)
Comment thread
chrishalcrow marked this conversation as resolved.
Outdated

self._main_ids = recording.get_channel_ids()
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)]

ChannelSliceRecording.__init__(
self,
recording,
channel_ids=new_channel_ids,
)

self._kwargs.update({"bad_channel_ids": bad_channel_ids, "channel_labels": channel_labels})
self._kwargs.update(updated_detect_bad_channels_kwargs)


remove_bad_channels = define_function_from_class(source_class=RemoveBadChannelsRecording, name="remove_bad_channels")


RemoveBadChannelsRecording.__doc__ = RemoveBadChannelsRecording.__doc__.format(RemoveBadChannelsRecording.params_doc)


def detect_bad_channels(
Expand All @@ -32,69 +146,7 @@ def detect_bad_channels(
Perform bad channel detection.
The recording is assumed to be filtered. If not, a highpass filter is applied on the fly.

Different methods are implemented:

* std : threhshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels"
* Out of brain channels are contigious regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks
{}
Comment thread
chrishalcrow marked this conversation as resolved.

Returns
-------
Expand Down Expand Up @@ -269,6 +321,9 @@ def detect_bad_channels(
return bad_channel_ids, channel_labels


detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(RemoveBadChannelsRecording.params_doc)


# ----------------------------------------------------------------------------------------------
# IBL Detect Bad Channels
# ----------------------------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/preprocessing/preprocessinglist.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation
from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter
from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels
from .detect_bad_channels import RemoveBadChannelsRecording, remove_bad_channels
from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction
from .directional_derivative import DirectionalDerivativeRecording, directional_derivative
from .depth_order import DepthOrderRecording, depth_order
Expand Down Expand Up @@ -79,6 +80,7 @@
DirectionalDerivativeRecording,
AstypeRecording,
UnsignedToSignedRecording,
RemoveBadChannelsRecording,
]

preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from spikeinterface import NumpyRecording, get_random_data_chunks
from probeinterface import generate_linear_probe

from spikeinterface.generation import generate_recording

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import detect_bad_channels, highpass_filter
from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, remove_bad_channels

try:
# WARNING : this is not this package https://pypi.org/project/neurodsp/
Expand All @@ -18,6 +20,45 @@
HAVE_NPIX = False


def test_remove_bad_channel():
"""
Generate a recording, then remove bad channels with a low noise threshold,
so that some units are removed. Then check that the new recording has none
of the bad channels still in it and that the one changed kwarg is successfully
propogated to the new recording.
"""

recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8)
recording.set_channel_offsets(0)
recording.set_channel_gains(1)

# set noisy_channel_threshold so that we do detect some bad channels
new_rec = remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205)

# make sure they are removed
bad_channel_ids = new_rec._kwargs["bad_channel_ids"]
assert len(set(bad_channel_ids).intersection(new_rec.channel_ids)) == 0
# and the good ones are kept
good_channel_ids = recording.channel_ids[~np.isin(recording.channel_ids, bad_channel_ids)]
assert set(good_channel_ids) == set(new_rec.channel_ids)

# and that the kwarg is propogatged to the kwargs of new_rec.
assert set(new_rec._kwargs["channel_ids"]) == set(good_channel_ids)
assert new_rec._kwargs["noisy_channel_threshold"] == 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything to worry about channel ordering here? I guess not (and this is a question for ChannelSliceRecording tests anyways. But as I have no understanding of how the ordering works I thought worth asking 😆

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ordering of the channel ids?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

erm like order of the channels on the recording itself (I'm not sure how exactly this is represented 😅 ). But like the default order when you do plot_traces without order_channel_by_depth


# now apply `detec_bad_channels` directly and see that the outputs matches
bad_channel_ids_from_function, channel_labels_from_function = detect_bad_channels(
recording, noisy_channel_threshold=0, seed=1205
)

assert np.all(new_rec._kwargs["bad_channel_ids"] == bad_channel_ids_from_function)
assert np.all(new_rec._kwargs["channel_labels"] == channel_labels_from_function)

new_rec_from_function = recording.remove_channels(remove_channel_ids=bad_channel_ids_from_function)

assert np.all(new_rec_from_function.channel_ids == new_rec.channel_ids)


def test_detect_bad_channels_std_mad():
num_channels = 4
sampling_frequency = 30000.0
Expand Down