From db8849e14d2d2caab7c007a652bbe13f6f561fae Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:16:38 +0000 Subject: [PATCH 01/13] add `remove_bad_channels` and `RemoveBadChannelsRecording` --- .../preprocessing/detect_bad_channels.py | 185 ++++++++++++------ .../preprocessing/preprocessinglist.py | 2 + 2 files changed, 122 insertions(+), 65 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 5d8f7107c7..e5ab347b3e 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -4,8 +4,122 @@ import numpy as np from typing import Literal -from .filter import highpass_filter -from ..core import get_random_data_chunks, order_channels_by_depth, BaseRecording +from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.preprocessing.filter import highpass_filter +from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording +from spikeinterface.core.channelslice import ChannelSliceRecording + +from inspect import signature + + +class RemoveBadChannelsRecording(ChannelSliceRecording): + """ + Removes bad channels. + + {} + + Returns + ------- + removed_bad_channels_recording : RemoveBadChannelsRecording + The recording with bad channels removed + """ + + params_doc = """Different methods are implemented: + +* std : threhshold on channel standard deviations + If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all + channels standard deviations, the channel is flagged as noisy +* mad : same as std, but using median absolute deviations instead +* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types: + * Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median) + * Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz) + and a high coherence with "far away" channels" + * Out of brain channels are contigious regions of channels dissimilar to the median of all channels + at the top end of the probe (i.e. large channel number) +* neighborhood_r2 + A method tuned for LFP use-cases, where channels should be highly correlated with their spatial + neighbors. This method estimates the correlation of each channel with the median of its spatial + neighbors, and considers channels bad when this correlation is too small. + +Parameters +---------- +recording : BaseRecording + The recording for which bad channels are detected +method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd" + The method to be used for bad channel detection +std_mad_threshold : float, default: 5 + The standard deviation/mad multiplier threshold +psd_hf_threshold : float, default: 0.02 + For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels. + Channels with average power at >80% Nyquist larger than this threshold + will be labeled as noise +dead_channel_threshold : float, default: -0.5 + For coherence+psd - threshold for channel coherence below which channels are labeled as dead +noisy_channel_threshold : float, default: 1 + Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) +outside_channel_threshold : float, default: -0.75 + For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside + of the brain +outside_channels_location : "top" | "bottom" | "both", default: "top" + For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be + marked as outside channels. If "bottom", only the channels at the bottom of the probe can be + marked as outside channels. If "both", both the channels at the top and bottom of the probe can be + marked as outside channels +n_neighbors : int, default: 11 + For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd) +nyquist_threshold : float, default: 0.8 + For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared + with psd_hf_threshold +direction : "x" | "y" | "z", default: "y" + For coherence+psd - the depth dimension +highpass_filter_cutoff : float, default: 300 + If the recording is not filtered, the cutoff frequency of the highpass filter +chunk_duration_s : float, default: 0.5 + Duration of each chunk +num_random_chunks : int, default: 100 + Number of random chunks + Having many chunks is important for reproducibility. +welch_window_ms : float, default: 10 + Window size for the scipy.signal.welch that will be converted to nperseg +neighborhood_r2_threshold : float, default: 0.95 + R^2 threshold for the neighborhood_r2 method. +neighborhood_r2_radius_um : float, default: 30 + Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. +seed : int or None, default: None + The random seed to extract chunks + + """ + + def __init__( + self, + recording: BaseRecording, + **detect_bad_channels_kwargs, + ): + + # get the default parameters from `detect_bad_channels`, and update with any user-specified parameters. + sig = signature(detect_bad_channels) + updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} + updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) + + bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **updated_detect_bad_channels_kwargs) + + self._main_ids = recording.get_channel_ids() + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)] + + ChannelSliceRecording.__init__( + self, + recording, + channel_ids=new_channel_ids, + ) + + self._kwargs.update({"channel_labels": channel_labels}) + self._kwargs.update(updated_detect_bad_channels_kwargs) + + +remove_bad_channels = define_function_from_class(source_class=RemoveBadChannelsRecording, name="remove_bad_channels") + + +RemoveBadChannelsRecording.__doc__ = RemoveBadChannelsRecording.__doc__.format(RemoveBadChannelsRecording.params_doc) def detect_bad_channels( @@ -32,69 +146,7 @@ def detect_bad_channels( Perform bad channel detection. The recording is assumed to be filtered. If not, a highpass filter is applied on the fly. - Different methods are implemented: - - * std : threhshold on channel standard deviations - If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all - channels standard deviations, the channel is flagged as noisy - * mad : same as std, but using median absolute deviations instead - * coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types: - * Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median) - * Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz) - and a high coherence with "far away" channels" - * Out of brain channels are contigious regions of channels dissimilar to the median of all channels - at the top end of the probe (i.e. large channel number) - * neighborhood_r2 - A method tuned for LFP use-cases, where channels should be highly correlated with their spatial - neighbors. This method estimates the correlation of each channel with the median of its spatial - neighbors, and considers channels bad when this correlation is too small. - - Parameters - ---------- - recording : BaseRecording - The recording for which bad channels are detected - method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd" - The method to be used for bad channel detection - std_mad_threshold : float, default: 5 - The standard deviation/mad multiplier threshold - psd_hf_threshold : float, default: 0.02 - For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels. - Channels with average power at >80% Nyquist larger than this threshold - will be labeled as noise - dead_channel_threshold : float, default: -0.5 - For coherence+psd - threshold for channel coherence below which channels are labeled as dead - noisy_channel_threshold : float, default: 1 - Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) - outside_channel_threshold : float, default: -0.75 - For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside - of the brain - outside_channels_location : "top" | "bottom" | "both", default: "top" - For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be - marked as outside channels. If "bottom", only the channels at the bottom of the probe can be - marked as outside channels. If "both", both the channels at the top and bottom of the probe can be - marked as outside channels - n_neighbors : int, default: 11 - For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd) - nyquist_threshold : float, default: 0.8 - For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared - with psd_hf_threshold - direction : "x" | "y" | "z", default: "y" - For coherence+psd - the depth dimension - highpass_filter_cutoff : float, default: 300 - If the recording is not filtered, the cutoff frequency of the highpass filter - chunk_duration_s : float, default: 0.5 - Duration of each chunk - num_random_chunks : int, default: 100 - Number of random chunks - Having many chunks is important for reproducibility. - welch_window_ms : float, default: 10 - Window size for the scipy.signal.welch that will be converted to nperseg - neighborhood_r2_threshold : float, default: 0.95 - R^2 threshold for the neighborhood_r2 method. - neighborhood_r2_radius_um : float, default: 30 - Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. - seed : int or None, default: None - The random seed to extract chunks + {} Returns ------- @@ -269,6 +321,9 @@ def detect_bad_channels( return bad_channel_ids, channel_labels +detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(RemoveBadChannelsRecording.params_doc) + + # ---------------------------------------------------------------------------------------------- # IBL Detect Bad Channels # ---------------------------------------------------------------------------------------------- diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index bdf5f2219c..13bdf8fe87 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -38,6 +38,7 @@ from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels +from .detect_bad_channels import RemoveBadChannelsRecording, remove_bad_channels from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction from .directional_derivative import DirectionalDerivativeRecording, directional_derivative from .depth_order import DepthOrderRecording, depth_order @@ -79,6 +80,7 @@ DirectionalDerivativeRecording, AstypeRecording, UnsignedToSignedRecording, + RemoveBadChannelsRecording, ] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} From 5138e1f6038c412e3125a97073fc5f60099c9aee Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:33:33 +0000 Subject: [PATCH 02/13] add test and propogate the bad_channel_ids --- .../preprocessing/detect_bad_channels.py | 2 +- .../tests/test_detect_bad_channels.py | 24 ++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index e5ab347b3e..f76c8b9aa6 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -112,7 +112,7 @@ def __init__( channel_ids=new_channel_ids, ) - self._kwargs.update({"channel_labels": channel_labels}) + self._kwargs.update({"bad_channel_ids": bad_channel_ids, "channel_labels": channel_labels}) self._kwargs.update(updated_detect_bad_channels_kwargs) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 4622be1440..1d87ed8c2c 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -4,8 +4,10 @@ from spikeinterface import NumpyRecording, get_random_data_chunks from probeinterface import generate_linear_probe +from spikeinterface.generation import generate_recording + from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import detect_bad_channels, highpass_filter +from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, remove_bad_channels try: # WARNING : this is not this package https://pypi.org/project/neurodsp/ @@ -18,6 +20,26 @@ HAVE_NPIX = False +def test_remove_bad_channel(): + """ + Generate a recording, then remove bad channels with a low noise threshold, + so that some units are removed. Then check that the new recording has none + of the bad channels still in it and that the one changed kwarg is successfully + propogated to the new recording. + """ + + recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8) + recording.set_channel_offsets(0) + recording.set_channel_gains(1) + # set noisy_channel_threshold so that we do detect some bad channels + new_rec = remove_bad_channels(recording, noisy_channel_threshold=0) + + # make sure they are removed + assert len(set(new_rec._kwargs["bad_channel_ids"]).intersection(new_rec.channel_ids)) == 0 + # and that the kwarg is propogatged to the kwargs of new_rec. + assert new_rec._kwargs["noisy_channel_threshold"] == 0 + + def test_detect_bad_channels_std_mad(): num_channels = 4 sampling_frequency = 30000.0 From ef0cc538adbe1015032b5282e4b365cf5daa2bf8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:01:01 +0000 Subject: [PATCH 03/13] Respond to Joe --- .../preprocessing/detect_bad_channels.py | 2 +- .../tests/test_detect_bad_channels.py | 23 +++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index f76c8b9aa6..8d87df26bf 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -101,7 +101,7 @@ def __init__( updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) - bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **updated_detect_bad_channels_kwargs) + bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) self._main_ids = recording.get_channel_ids() new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)] diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 1d87ed8c2c..0815a71948 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -31,14 +31,33 @@ def test_remove_bad_channel(): recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8) recording.set_channel_offsets(0) recording.set_channel_gains(1) + # set noisy_channel_threshold so that we do detect some bad channels - new_rec = remove_bad_channels(recording, noisy_channel_threshold=0) + new_rec = remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205) # make sure they are removed - assert len(set(new_rec._kwargs["bad_channel_ids"]).intersection(new_rec.channel_ids)) == 0 + bad_channel_ids = new_rec._kwargs["bad_channel_ids"] + assert len(set(bad_channel_ids).intersection(new_rec.channel_ids)) == 0 + # and the good ones are kept + good_channel_ids = recording.channel_ids[~np.isin(recording.channel_ids, bad_channel_ids)] + assert set(good_channel_ids) == set(new_rec.channel_ids) + # and that the kwarg is propogatged to the kwargs of new_rec. + assert set(new_rec._kwargs["channel_ids"]) == set(good_channel_ids) assert new_rec._kwargs["noisy_channel_threshold"] == 0 + # now apply `detec_bad_channels` directly and see that the outputs matches + bad_channel_ids_from_function, channel_labels_from_function = detect_bad_channels( + recording, noisy_channel_threshold=0, seed=1205 + ) + + assert np.all(new_rec._kwargs["bad_channel_ids"] == bad_channel_ids_from_function) + assert np.all(new_rec._kwargs["channel_labels"] == channel_labels_from_function) + + new_rec_from_function = recording.remove_channels(remove_channel_ids=bad_channel_ids_from_function) + + assert np.all(new_rec_from_function.channel_ids == new_rec.channel_ids) + def test_detect_bad_channels_std_mad(): num_channels = 4 From a5e8af53ef9d0e2e0ee27b496313d3891059c05a Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 19 Mar 2025 10:50:18 +0000 Subject: [PATCH 04/13] fix bug --- .../preprocessing/detect_bad_channels.py | 54 +++++++++++-------- .../tests/test_detect_bad_channels.py | 4 +- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index fe08413a49..c49cc48c60 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -4,27 +4,14 @@ import numpy as np from typing import Literal -from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .filter import highpass_filter from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording from spikeinterface.core.channelslice import ChannelSliceRecording from inspect import signature - -class RemoveBadChannelsRecording(ChannelSliceRecording): - """ - Removes bad channels. - - {} - - Returns - ------- - removed_bad_channels_recording : RemoveBadChannelsRecording - The recording with bad channels removed - """ - - params_doc = """Different methods are implemented: +_bad_channel_detection_kwargs_doc = """Different methods are implemented: * std : threhshold on channel standard deviations If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all @@ -87,12 +74,28 @@ class RemoveBadChannelsRecording(ChannelSliceRecording): Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. seed : int or None, default: None The random seed to extract chunks +""" + +class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): + """ + Detects and removes bad channels. If `bad_channel_ids` are given, + the detection is skipped and uses these instead. + + {} + bad_channel_ids : np.array | list | None, default: None + If given, these are used rather than being dected. + + Returns + ------- + removed_bad_channels_recording : DetectAndRemoveBadChannelsRecording + The recording with bad channels removed """ def __init__( self, recording: BaseRecording, + bad_channel_ids=None, **detect_bad_channels_kwargs, ): @@ -101,7 +104,10 @@ def __init__( updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) - bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) + if bad_channel_ids is None: + bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) + else: + channel_labels = None self._main_ids = recording.get_channel_ids() new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)] @@ -112,14 +118,18 @@ def __init__( channel_ids=new_channel_ids, ) - self._kwargs.update({"bad_channel_ids": bad_channel_ids, "channel_labels": channel_labels}) + self._kwargs.update({"bad_channel_ids": bad_channel_ids}) + if channel_labels is not None: + self._kwargs.update({"channel_labels": channel_labels}) self._kwargs.update(updated_detect_bad_channels_kwargs) -remove_bad_channels = define_function_from_class(source_class=RemoveBadChannelsRecording, name="remove_bad_channels") - - -RemoveBadChannelsRecording.__doc__ = RemoveBadChannelsRecording.__doc__.format(RemoveBadChannelsRecording.params_doc) +detect_and_remove_bad_channels = define_function_handling_dict_from_class( + source_class=DetectAndRemoveBadChannelsRecording, name="detect_and_remove_bad_channels" +) +DetectAndRemoveBadChannelsRecording.__doc__ = DetectAndRemoveBadChannelsRecording.__doc__.format( + _bad_channel_detection_kwargs_doc +) def detect_bad_channels( @@ -321,7 +331,7 @@ def detect_bad_channels( return bad_channel_ids, channel_labels -detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(RemoveBadChannelsRecording.params_doc) +detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(_bad_channel_detection_kwargs_doc) # ---------------------------------------------------------------------------------------------- diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 0815a71948..38bb35749e 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -7,7 +7,7 @@ from spikeinterface.generation import generate_recording from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, remove_bad_channels +from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, detect_and_remove_bad_channels try: # WARNING : this is not this package https://pypi.org/project/neurodsp/ @@ -33,7 +33,7 @@ def test_remove_bad_channel(): recording.set_channel_gains(1) # set noisy_channel_threshold so that we do detect some bad channels - new_rec = remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + new_rec = detect_and_remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205) # make sure they are removed bad_channel_ids = new_rec._kwargs["bad_channel_ids"] From 10c5d413f195ffc74ccb548c19bdf2f96012e409 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 19 Mar 2025 10:51:38 +0000 Subject: [PATCH 05/13] add `detect_and_interpolate_bad_channels` --- .../preprocessing/interpolate_bad_channels.py | 58 ++++++++++++++++++- .../tests/test_interpolate_bad_channels.py | 28 ++++++++- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 87bb0f936b..8507cae1d7 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -2,9 +2,11 @@ import numpy as np -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing import preprocessing_tools +from .detect_bad_channels import _bad_channel_detection_kwargs_doc, detect_bad_channels +from inspect import signature class InterpolateBadChannelsRecording(BasePreprocessor): @@ -82,6 +84,60 @@ def check_inputs(self, recording, bad_channel_ids): raise NotImplementedError("Channel spacing units must be um") +class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording): + """ + Detects and interpolates bad channels. If `bad_channel_ids` are given, + the detection is skipped and uses these instead. + + {} + bad_channel_ids : np.array | list | None, default: None + If given, these are used rather than being dected. + + Returns + ------- + interpolated_bad_channels_recording : DetectAndInterpolateBadChannelsRecording + The recording with bad channels removed + """ + + def __init__( + self, + recording: BaseRecording, + bad_channel_ids=None, + **detect_bad_channels_kwargs, + ): + + # get the default parameters from `detect_bad_channels`, and update with any user-specified parameters. + sig = signature(detect_bad_channels) + updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} + updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) + + if bad_channel_ids is None: + bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) + else: + channel_labels = None + + self._main_ids = recording.get_channel_ids() + + InterpolateBadChannelsRecording.__init__( + self, + recording, + bad_channel_ids=bad_channel_ids, + ) + + self._kwargs.update({"bad_channel_ids": bad_channel_ids}) + if channel_labels is not None: + self._kwargs.update({"channel_labels": channel_labels}) + self._kwargs.update(updated_detect_bad_channels_kwargs) + + +detect_and_interpolate_bad_channels = define_function_handling_dict_from_class( + source_class=DetectAndInterpolateBadChannelsRecording, name="detect_and_interpolate_bad_channels" +) +DetectAndInterpolateBadChannelsRecording.__doc__ = DetectAndInterpolateBadChannelsRecording.__doc__.format( + _bad_channel_detection_kwargs_doc +) + + class InterpolateBadChannelsSegment(BasePreprocessorSegment): def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_indices, weights): BasePreprocessorSegment.__init__(self, parent_recording_segment) diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 06bde4e3d1..e05cccebd2 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -7,7 +7,8 @@ import spikeinterface.extractors as se from spikeinterface.core.generate import generate_recording import importlib.util - +from spikeinterface.preprocessing.interpolate_bad_channels import detect_and_interpolate_bad_channels +from spikeinterface.preprocessing.detect_bad_channels import detect_bad_channels ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) DEBUG = False @@ -23,6 +24,31 @@ # ------------------------------------------------------------------------------- +def test_detect_and_interpolate_bad_channel(): + """ + Generate a recording, then remove bad channels with a low noise threshold, so that + some units are removed. Then check that the new recording is an interpolated + recording and that kwargs are successfully propogated to the new recording. + """ + + recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8) + recording.set_channel_offsets(0) + recording.set_channel_gains(1) + + # find the bad channels directly + bad_channel_ids, _ = detect_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + + # set noisy_channel_threshold so that we do detect some bad channels + new_rec = detect_and_interpolate_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + + # make sure they are in the new recording kwargs + bad_channel_ids_from_rec = new_rec._kwargs["bad_channel_ids"] + assert set(bad_channel_ids) == set(bad_channel_ids_from_rec) + + # and that the kwarg is propogatged to the kwargs of new_rec. + assert new_rec._kwargs["noisy_channel_threshold"] == 0 + + @pytest.mark.skipif( importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install", From 2a199e9bc11a373429919cd6d6d237f41dc0be81 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 19 Mar 2025 10:51:55 +0000 Subject: [PATCH 06/13] update `preprocessinglist` --- .../preprocessing/preprocessinglist.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index b4e1194711..bf958f2d17 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -37,8 +37,13 @@ from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter -from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels -from .detect_bad_channels import RemoveBadChannelsRecording, remove_bad_channels +from .interpolate_bad_channels import ( + DetectAndInterpolateBadChannelsRecording, + detect_and_interpolate_bad_channels, + InterpolateBadChannelsRecording, + interpolate_bad_channels, +) +from .detect_bad_channels import DetectAndRemoveBadChannelsRecording, detect_and_remove_bad_channels from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction from .directional_derivative import DirectionalDerivativeRecording, directional_derivative from .depth_order import DepthOrderRecording, depth_order @@ -80,7 +85,8 @@ DirectionalDerivativeRecording, AstypeRecording, UnsignedToSignedRecording, - RemoveBadChannelsRecording, + DetectAndRemoveBadChannelsRecording, + DetectAndInterpolateBadChannelsRecording, ] preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} From 6c2fdeefea8b77015fca8bc8239de16682f37ebb Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 20 Mar 2025 14:33:37 +0000 Subject: [PATCH 07/13] refactor kwarg update --- .../preprocessing/detect_bad_channels.py | 18 ++++++++++++------ .../preprocessing/interpolate_bad_channels.py | 15 ++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index c49cc48c60..20bd3a5ebe 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -99,11 +99,6 @@ def __init__( **detect_bad_channels_kwargs, ): - # get the default parameters from `detect_bad_channels`, and update with any user-specified parameters. - sig = signature(detect_bad_channels) - updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} - updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) - if bad_channel_ids is None: bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) else: @@ -121,7 +116,9 @@ def __init__( self._kwargs.update({"bad_channel_ids": bad_channel_ids}) if channel_labels is not None: self._kwargs.update({"channel_labels": channel_labels}) - self._kwargs.update(updated_detect_bad_channels_kwargs) + + all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs) + self._kwargs.update(all_bad_channels_kwargs) detect_and_remove_bad_channels = define_function_handling_dict_from_class( @@ -132,6 +129,15 @@ def __init__( ) +def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs): + """Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters.""" + + sig = signature(detect_bad_channels) + all_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} + all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) + return all_detect_bad_channels_kwargs + + def detect_bad_channels( recording: BaseRecording, method: str = "coherence+psd", diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 8507cae1d7..3e67d853dd 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -5,7 +5,11 @@ from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing import preprocessing_tools -from .detect_bad_channels import _bad_channel_detection_kwargs_doc, detect_bad_channels +from .detect_bad_channels import ( + _bad_channel_detection_kwargs_doc, + detect_bad_channels, + _get_all_detect_bad_channel_kwargs, +) from inspect import signature @@ -106,11 +110,6 @@ def __init__( **detect_bad_channels_kwargs, ): - # get the default parameters from `detect_bad_channels`, and update with any user-specified parameters. - sig = signature(detect_bad_channels) - updated_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} - updated_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) - if bad_channel_ids is None: bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) else: @@ -127,7 +126,9 @@ def __init__( self._kwargs.update({"bad_channel_ids": bad_channel_ids}) if channel_labels is not None: self._kwargs.update({"channel_labels": channel_labels}) - self._kwargs.update(updated_detect_bad_channels_kwargs) + + all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs) + self._kwargs.update(all_bad_channels_kwargs) detect_and_interpolate_bad_channels = define_function_handling_dict_from_class( From 42708853a7b6f8e4135f3fc168fc75dfcac01dbd Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 1 Apr 2025 15:30:27 +0100 Subject: [PATCH 08/13] add _precomputable_kwarg_names --- src/spikeinterface/preprocessing/detect_bad_channels.py | 2 ++ src/spikeinterface/preprocessing/interpolate_bad_channels.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 20bd3a5ebe..da89589772 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -92,6 +92,8 @@ class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): The recording with bad channels removed """ + _precomputable_kwarg_names = ["bad_channel_ids"] + def __init__( self, recording: BaseRecording, diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 3e67d853dd..36b3e44ce0 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -103,13 +103,14 @@ class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording): The recording with bad channels removed """ + _precomputable_kwarg_names = ["bad_channel_ids"] + def __init__( self, recording: BaseRecording, bad_channel_ids=None, **detect_bad_channels_kwargs, ): - if bad_channel_ids is None: bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) else: From 658a0158734585ee09707a67018cfba0cf171205 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Thu, 3 Apr 2025 11:27:36 +0100 Subject: [PATCH 09/13] recording -> parent_recording --- .../preprocessing/detect_bad_channels.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index da89589772..405da9d05c 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -96,22 +96,24 @@ class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): def __init__( self, - recording: BaseRecording, + parent_recording: BaseRecording, bad_channel_ids=None, **detect_bad_channels_kwargs, ): if bad_channel_ids is None: - bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) + bad_channel_ids, channel_labels = detect_bad_channels( + recording=parent_recording, **detect_bad_channels_kwargs + ) else: channel_labels = None - self._main_ids = recording.get_channel_ids() + self._main_ids = parent_recording.get_channel_ids() new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)] ChannelSliceRecording.__init__( self, - recording, + parent_recording=parent_recording, channel_ids=new_channel_ids, ) From 1f1d8c8117ffeac36ef7e474c55ee5eaaaaecec4 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 30 Apr 2025 14:22:05 +0100 Subject: [PATCH 10/13] remove _main_ids from detect_and_interpolate --- src/spikeinterface/preprocessing/interpolate_bad_channels.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 36b3e44ce0..33f1de04ec 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -116,8 +116,6 @@ def __init__( else: channel_labels = None - self._main_ids = recording.get_channel_ids() - InterpolateBadChannelsRecording.__init__( self, recording, From 291993f5ae42083f1dd493428a3b3bbcefe69d89 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 28 May 2025 13:09:17 +0100 Subject: [PATCH 11/13] channel_labels for prov --- .../preprocessing/detect_bad_channels.py | 10 ++++++++-- .../preprocessing/interpolate_bad_channels.py | 5 ++++- .../preprocessing/tests/test_detect_bad_channels.py | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 405da9d05c..e7b50c6935 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -84,7 +84,10 @@ class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): {} bad_channel_ids : np.array | list | None, default: None - If given, these are used rather than being dected. + If given, these are used rather than being detected. + channel_labels : np.array | list | None, default: None + If given, these are labels given to the channels by the + detection process. Only intended for use when loading. Returns ------- @@ -98,6 +101,7 @@ def __init__( self, parent_recording: BaseRecording, bad_channel_ids=None, + channel_labels=None, **detect_bad_channels_kwargs, ): @@ -137,7 +141,9 @@ def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs): """Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters.""" sig = signature(detect_bad_channels) - all_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"} + all_detect_bad_channels_kwargs = { + k: v.default for k, v in sig.parameters.items() if k not in ["recording", "parent_recording"] + } all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) return all_detect_bad_channels_kwargs diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 33f1de04ec..e345770273 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -95,7 +95,10 @@ class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording): {} bad_channel_ids : np.array | list | None, default: None - If given, these are used rather than being dected. + If given, these are used rather than being detected. + channel_labels : np.array | list | None, default: None + If given, these are labels given to the channels by the + detection process. Only intended for use when loading. Returns ------- diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 38bb35749e..713a2f91e9 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -46,7 +46,7 @@ def test_remove_bad_channel(): assert set(new_rec._kwargs["channel_ids"]) == set(good_channel_ids) assert new_rec._kwargs["noisy_channel_threshold"] == 0 - # now apply `detec_bad_channels` directly and see that the outputs matches + # now apply `detect_bad_channels` directly and see that the outputs matches bad_channel_ids_from_function, channel_labels_from_function = detect_bad_channels( recording, noisy_channel_threshold=0, seed=1205 ) From 56da71c8f9841641e94e4bc67be69b9e556f9b07 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Wed, 28 May 2025 13:16:29 +0100 Subject: [PATCH 12/13] add some docs --- doc/how_to/process_by_channel_group.rst | 3 ++- doc/modules/preprocessing.rst | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 2b9e0e12a9..334f83b247 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -97,12 +97,13 @@ to any preprocessing function. shifted_recordings = spre.phase_shift(split_recording_dict) filtered_recording = spre.bandpass_filter(shifted_recording) referenced_recording = spre.common_reference(filtered_recording) + good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording) We can then aggregate the recordings back together using the ``aggregate_channels`` function .. code-block:: python - combined_preprocessed_recording = aggregate_channels(referenced_recording) + combined_preprocessed_recording = aggregate_channels(good_channels_recording) Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever calling its :py:func:`~get_traces` method, the data will have been diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 4cb16d6855..e5b915844b 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -247,6 +247,18 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe # Case 2 : interpolate then rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) +Once you have tested these functions and decided on your workflow, you can use the `detect_and_*` +functions to do everything at once. These return a Preprocessor class, so are consistent with +the "chain" concept for this module. For example: + +.. code-block:: python + + # detect and remove bad channels + rec_only_good_channels = detect_and_remove_bad_channels(recording=rec) + + # detect and interpolate the bad channels + rec_interpolated_channels = detect_and_interpolate_bad_channels(remove_channel_ids=bad_channel_ids) + * :py:func:`~spikeinterface.preprocessing.detect_bad_channels()` * :py:func:`~spikeinterface.preprocessing.interpolate_bad_channels()` From a2a6f18b6d9abc12fa554038f7143f2759901b2b Mon Sep 17 00:00:00 2001 From: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 12 Jun 2025 09:44:56 +0100 Subject: [PATCH 13/13] Update src/spikeinterface/preprocessing/detect_bad_channels.py Co-authored-by: Garcia Samuel --- src/spikeinterface/preprocessing/detect_bad_channels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index e7b50c6935..42ffd712d8 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -95,7 +95,7 @@ class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): The recording with bad channels removed """ - _precomputable_kwarg_names = ["bad_channel_ids"] + _precomputable_kwarg_names = ["bad_channel_ids", "channel_labels"] def __init__( self,