diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 2b9e0e12a9..334f83b247 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -97,12 +97,13 @@ to any preprocessing function. shifted_recordings = spre.phase_shift(split_recording_dict) filtered_recording = spre.bandpass_filter(shifted_recording) referenced_recording = spre.common_reference(filtered_recording) + good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording) We can then aggregate the recordings back together using the ``aggregate_channels`` function .. code-block:: python - combined_preprocessed_recording = aggregate_channels(referenced_recording) + combined_preprocessed_recording = aggregate_channels(good_channels_recording) Now, when ``combined_preprocessed_recording`` is used in sorting, plotting, or whenever calling its :py:func:`~get_traces` method, the data will have been diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 4cb16d6855..fcffe08c05 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -247,6 +247,18 @@ interpolated with the :code:`interpolate_bad_channels()` function (channels labe # Case 2 : interpolate then rec_clean = interpolate_bad_channels(recording=rec, bad_channel_ids=bad_channel_ids) +Once you have tested these functions and decided on your workflow, you can use the `detect_and_*` +functions to do everything at once. These return a Preprocessor class, so are consistent with +the "chain" concept for this module. For example: + +.. code-block:: python + + # detect and remove bad channels + rec_only_good_channels = detect_and_remove_bad_channels(recording=rec) + + # detect and interpolate the bad channels + rec_interpolated_channels = detect_and_interpolate_bad_channels(recording=rec) + * :py:func:`~spikeinterface.preprocessing.detect_bad_channels()` * :py:func:`~spikeinterface.preprocessing.interpolate_bad_channels()` diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 2175351f0b..42ffd712d8 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -4,8 +4,148 @@ import numpy as np from typing import Literal +from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .filter import highpass_filter from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording +from spikeinterface.core.channelslice import ChannelSliceRecording + +from inspect import signature + +_bad_channel_detection_kwargs_doc = """Different methods are implemented: + +* std : threhshold on channel standard deviations + If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all + channels standard deviations, the channel is flagged as noisy +* mad : same as std, but using median absolute deviations instead +* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types: + * Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median) + * Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz) + and a high coherence with "far away" channels" + * Out of brain channels are contigious regions of channels dissimilar to the median of all channels + at the top end of the probe (i.e. large channel number) +* neighborhood_r2 + A method tuned for LFP use-cases, where channels should be highly correlated with their spatial + neighbors. This method estimates the correlation of each channel with the median of its spatial + neighbors, and considers channels bad when this correlation is too small. + +Parameters +---------- +recording : BaseRecording + The recording for which bad channels are detected +method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd" + The method to be used for bad channel detection +std_mad_threshold : float, default: 5 + The standard deviation/mad multiplier threshold +psd_hf_threshold : float, default: 0.02 + For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels. + Channels with average power at >80% Nyquist larger than this threshold + will be labeled as noise +dead_channel_threshold : float, default: -0.5 + For coherence+psd - threshold for channel coherence below which channels are labeled as dead +noisy_channel_threshold : float, default: 1 + Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) +outside_channel_threshold : float, default: -0.75 + For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside + of the brain +outside_channels_location : "top" | "bottom" | "both", default: "top" + For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be + marked as outside channels. If "bottom", only the channels at the bottom of the probe can be + marked as outside channels. If "both", both the channels at the top and bottom of the probe can be + marked as outside channels +n_neighbors : int, default: 11 + For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd) +nyquist_threshold : float, default: 0.8 + For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared + with psd_hf_threshold +direction : "x" | "y" | "z", default: "y" + For coherence+psd - the depth dimension +highpass_filter_cutoff : float, default: 300 + If the recording is not filtered, the cutoff frequency of the highpass filter +chunk_duration_s : float, default: 0.5 + Duration of each chunk +num_random_chunks : int, default: 100 + Number of random chunks + Having many chunks is important for reproducibility. +welch_window_ms : float, default: 10 + Window size for the scipy.signal.welch that will be converted to nperseg +neighborhood_r2_threshold : float, default: 0.95 + R^2 threshold for the neighborhood_r2 method. +neighborhood_r2_radius_um : float, default: 30 + Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. +seed : int or None, default: None + The random seed to extract chunks +""" + + +class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording): + """ + Detects and removes bad channels. If `bad_channel_ids` are given, + the detection is skipped and uses these instead. + + {} + bad_channel_ids : np.array | list | None, default: None + If given, these are used rather than being detected. + channel_labels : np.array | list | None, default: None + If given, these are labels given to the channels by the + detection process. Only intended for use when loading. + + Returns + ------- + removed_bad_channels_recording : DetectAndRemoveBadChannelsRecording + The recording with bad channels removed + """ + + _precomputable_kwarg_names = ["bad_channel_ids", "channel_labels"] + + def __init__( + self, + parent_recording: BaseRecording, + bad_channel_ids=None, + channel_labels=None, + **detect_bad_channels_kwargs, + ): + + if bad_channel_ids is None: + bad_channel_ids, channel_labels = detect_bad_channels( + recording=parent_recording, **detect_bad_channels_kwargs + ) + else: + channel_labels = None + + self._main_ids = parent_recording.get_channel_ids() + new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)] + + ChannelSliceRecording.__init__( + self, + parent_recording=parent_recording, + channel_ids=new_channel_ids, + ) + + self._kwargs.update({"bad_channel_ids": bad_channel_ids}) + if channel_labels is not None: + self._kwargs.update({"channel_labels": channel_labels}) + + all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs) + self._kwargs.update(all_bad_channels_kwargs) + + +detect_and_remove_bad_channels = define_function_handling_dict_from_class( + source_class=DetectAndRemoveBadChannelsRecording, name="detect_and_remove_bad_channels" +) +DetectAndRemoveBadChannelsRecording.__doc__ = DetectAndRemoveBadChannelsRecording.__doc__.format( + _bad_channel_detection_kwargs_doc +) + + +def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs): + """Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters.""" + + sig = signature(detect_bad_channels) + all_detect_bad_channels_kwargs = { + k: v.default for k, v in sig.parameters.items() if k not in ["recording", "parent_recording"] + } + all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs) + return all_detect_bad_channels_kwargs def detect_bad_channels( @@ -32,69 +172,7 @@ def detect_bad_channels( Perform bad channel detection. The recording is assumed to be filtered. If not, a highpass filter is applied on the fly. - Different methods are implemented: - - * std : threhshold on channel standard deviations - If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all - channels standard deviations, the channel is flagged as noisy - * mad : same as std, but using median absolute deviations instead - * coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types: - * Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median) - * Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz) - and a high coherence with "far away" channels" - * Out of brain channels are contigious regions of channels dissimilar to the median of all channels - at the top end of the probe (i.e. large channel number) - * neighborhood_r2 - A method tuned for LFP use-cases, where channels should be highly correlated with their spatial - neighbors. This method estimates the correlation of each channel with the median of its spatial - neighbors, and considers channels bad when this correlation is too small. - - Parameters - ---------- - recording : BaseRecording - The recording for which bad channels are detected - method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd" - The method to be used for bad channel detection - std_mad_threshold : float, default: 5 - The standard deviation/mad multiplier threshold - psd_hf_threshold : float, default: 0.02 - For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels. - Channels with average power at >80% Nyquist larger than this threshold - will be labeled as noise - dead_channel_threshold : float, default: -0.5 - For coherence+psd - threshold for channel coherence below which channels are labeled as dead - noisy_channel_threshold : float, default: 1 - Threshold for channel coherence above which channels are labeled as noisy (together with psd condition) - outside_channel_threshold : float, default: -0.75 - For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside - of the brain - outside_channels_location : "top" | "bottom" | "both", default: "top" - For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be - marked as outside channels. If "bottom", only the channels at the bottom of the probe can be - marked as outside channels. If "both", both the channels at the top and bottom of the probe can be - marked as outside channels - n_neighbors : int, default: 11 - For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd) - nyquist_threshold : float, default: 0.8 - For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared - with psd_hf_threshold - direction : "x" | "y" | "z", default: "y" - For coherence+psd - the depth dimension - highpass_filter_cutoff : float, default: 300 - If the recording is not filtered, the cutoff frequency of the highpass filter - chunk_duration_s : float, default: 0.5 - Duration of each chunk - num_random_chunks : int, default: 100 - Number of random chunks - Having many chunks is important for reproducibility. - welch_window_ms : float, default: 10 - Window size for the scipy.signal.welch that will be converted to nperseg - neighborhood_r2_threshold : float, default: 0.95 - R^2 threshold for the neighborhood_r2 method. - neighborhood_r2_radius_um : float, default: 30 - Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method. - seed : int or None, default: None - The random seed to extract chunks + {} Returns ------- @@ -269,6 +347,9 @@ def detect_bad_channels( return bad_channel_ids, channel_labels +detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(_bad_channel_detection_kwargs_doc) + + # ---------------------------------------------------------------------------------------------- # IBL Detect Bad Channels # ---------------------------------------------------------------------------------------------- diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 87bb0f936b..e345770273 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -2,9 +2,15 @@ import numpy as np -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording from spikeinterface.core.core_tools import define_function_handling_dict_from_class from spikeinterface.preprocessing import preprocessing_tools +from .detect_bad_channels import ( + _bad_channel_detection_kwargs_doc, + detect_bad_channels, + _get_all_detect_bad_channel_kwargs, +) +from inspect import signature class InterpolateBadChannelsRecording(BasePreprocessor): @@ -82,6 +88,59 @@ def check_inputs(self, recording, bad_channel_ids): raise NotImplementedError("Channel spacing units must be um") +class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording): + """ + Detects and interpolates bad channels. If `bad_channel_ids` are given, + the detection is skipped and uses these instead. + + {} + bad_channel_ids : np.array | list | None, default: None + If given, these are used rather than being detected. + channel_labels : np.array | list | None, default: None + If given, these are labels given to the channels by the + detection process. Only intended for use when loading. + + Returns + ------- + interpolated_bad_channels_recording : DetectAndInterpolateBadChannelsRecording + The recording with bad channels removed + """ + + _precomputable_kwarg_names = ["bad_channel_ids"] + + def __init__( + self, + recording: BaseRecording, + bad_channel_ids=None, + **detect_bad_channels_kwargs, + ): + if bad_channel_ids is None: + bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs) + else: + channel_labels = None + + InterpolateBadChannelsRecording.__init__( + self, + recording, + bad_channel_ids=bad_channel_ids, + ) + + self._kwargs.update({"bad_channel_ids": bad_channel_ids}) + if channel_labels is not None: + self._kwargs.update({"channel_labels": channel_labels}) + + all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs) + self._kwargs.update(all_bad_channels_kwargs) + + +detect_and_interpolate_bad_channels = define_function_handling_dict_from_class( + source_class=DetectAndInterpolateBadChannelsRecording, name="detect_and_interpolate_bad_channels" +) +DetectAndInterpolateBadChannelsRecording.__doc__ = DetectAndInterpolateBadChannelsRecording.__doc__.format( + _bad_channel_detection_kwargs_doc +) + + class InterpolateBadChannelsSegment(BasePreprocessorSegment): def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_indices, weights): BasePreprocessorSegment.__init__(self, parent_recording_segment) diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 705ed3428b..e7f42bc73b 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -38,7 +38,13 @@ from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter -from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels +from .interpolate_bad_channels import ( + DetectAndInterpolateBadChannelsRecording, + detect_and_interpolate_bad_channels, + InterpolateBadChannelsRecording, + interpolate_bad_channels, +) +from .detect_bad_channels import DetectAndRemoveBadChannelsRecording, detect_and_remove_bad_channels from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction from .directional_derivative import DirectionalDerivativeRecording, directional_derivative from .depth_order import DepthOrderRecording, depth_order @@ -63,6 +69,9 @@ # re-reference CommonReferenceRecording: common_reference, PhaseShiftRecording: phase_shift, + # bad channel detection/interpolation + DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels, + DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels, # misc RectifyRecording: rectify, ClipRecording: clip, diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index f0e22c4afd..119cddeee1 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -6,8 +6,10 @@ from spikeinterface import NumpyRecording, get_random_data_chunks from probeinterface import generate_linear_probe +from spikeinterface.generation import generate_recording + from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import detect_bad_channels, highpass_filter +from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, detect_and_remove_bad_channels # WARNING : this is not this package https://pypi.org/project/neurodsp/ # BUT this one https://github.com/int-brain-lab/ibl-neuropixel @@ -20,6 +22,45 @@ HAVE_NPIX = False +def test_remove_bad_channel(): + """ + Generate a recording, then remove bad channels with a low noise threshold, + so that some units are removed. Then check that the new recording has none + of the bad channels still in it and that the one changed kwarg is successfully + propogated to the new recording. + """ + + recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8) + recording.set_channel_offsets(0) + recording.set_channel_gains(1) + + # set noisy_channel_threshold so that we do detect some bad channels + new_rec = detect_and_remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + + # make sure they are removed + bad_channel_ids = new_rec._kwargs["bad_channel_ids"] + assert len(set(bad_channel_ids).intersection(new_rec.channel_ids)) == 0 + # and the good ones are kept + good_channel_ids = recording.channel_ids[~np.isin(recording.channel_ids, bad_channel_ids)] + assert set(good_channel_ids) == set(new_rec.channel_ids) + + # and that the kwarg is propogatged to the kwargs of new_rec. + assert set(new_rec._kwargs["channel_ids"]) == set(good_channel_ids) + assert new_rec._kwargs["noisy_channel_threshold"] == 0 + + # now apply `detect_bad_channels` directly and see that the outputs matches + bad_channel_ids_from_function, channel_labels_from_function = detect_bad_channels( + recording, noisy_channel_threshold=0, seed=1205 + ) + + assert np.all(new_rec._kwargs["bad_channel_ids"] == bad_channel_ids_from_function) + assert np.all(new_rec._kwargs["channel_labels"] == channel_labels_from_function) + + new_rec_from_function = recording.remove_channels(remove_channel_ids=bad_channel_ids_from_function) + + assert np.all(new_rec_from_function.channel_ids == new_rec.channel_ids) + + def test_detect_bad_channels_std_mad(): num_channels = 4 sampling_frequency = 30000.0 diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 06bde4e3d1..e05cccebd2 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -7,7 +7,8 @@ import spikeinterface.extractors as se from spikeinterface.core.generate import generate_recording import importlib.util - +from spikeinterface.preprocessing.interpolate_bad_channels import detect_and_interpolate_bad_channels +from spikeinterface.preprocessing.detect_bad_channels import detect_bad_channels ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) DEBUG = False @@ -23,6 +24,31 @@ # ------------------------------------------------------------------------------- +def test_detect_and_interpolate_bad_channel(): + """ + Generate a recording, then remove bad channels with a low noise threshold, so that + some units are removed. Then check that the new recording is an interpolated + recording and that kwargs are successfully propogated to the new recording. + """ + + recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8) + recording.set_channel_offsets(0) + recording.set_channel_gains(1) + + # find the bad channels directly + bad_channel_ids, _ = detect_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + + # set noisy_channel_threshold so that we do detect some bad channels + new_rec = detect_and_interpolate_bad_channels(recording, noisy_channel_threshold=0, seed=1205) + + # make sure they are in the new recording kwargs + bad_channel_ids_from_rec = new_rec._kwargs["bad_channel_ids"] + assert set(bad_channel_ids) == set(bad_channel_ids_from_rec) + + # and that the kwarg is propogatged to the kwargs of new_rec. + assert new_rec._kwargs["noisy_channel_threshold"] == 0 + + @pytest.mark.skipif( importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install",