|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from spikeinterface.core.core_tools import define_function_handling_dict_from_class |
| 6 | +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording |
| 7 | +from spikeinterface.preprocessing.rectify import RectifyRecording |
| 8 | +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording |
| 9 | +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording |
| 10 | +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs |
| 11 | +from spikeinterface.core.recording_tools import get_noise_levels |
| 12 | +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype |
| 13 | +import numpy as np |
| 14 | + |
| 15 | + |
| 16 | +class DetectThresholdCrossing(PeakDetector): |
| 17 | + |
| 18 | + name = "threshold_crossings" |
| 19 | + preferred_mp_context = None |
| 20 | + |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + recording, |
| 24 | + detect_threshold=5, |
| 25 | + noise_levels=None, |
| 26 | + seed=None, |
| 27 | + noise_levels_kwargs=dict(), |
| 28 | + ): |
| 29 | + PeakDetector.__init__(self, recording, return_output=True) |
| 30 | + if noise_levels is None: |
| 31 | + random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() |
| 32 | + random_slices_kwargs["seed"] = seed |
| 33 | + noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) |
| 34 | + self.abs_thresholds = noise_levels * detect_threshold |
| 35 | + self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) |
| 36 | + |
| 37 | + def get_trace_margin(self): |
| 38 | + return 0 |
| 39 | + |
| 40 | + def get_dtype(self): |
| 41 | + return self._dtype |
| 42 | + |
| 43 | + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): |
| 44 | + z = np.median(traces / self.abs_thresholds, 1) |
| 45 | + threshold_mask = np.diff((z > 1) != 0, axis=0) |
| 46 | + indices = np.flatnonzero(threshold_mask) |
| 47 | + threshold_crossings = np.zeros(indices.size, dtype=self._dtype) |
| 48 | + threshold_crossings["sample_index"] = indices |
| 49 | + threshold_crossings["front"][::2] = True |
| 50 | + threshold_crossings["front"][1::2] = False |
| 51 | + return (threshold_crossings,) |
| 52 | + |
| 53 | + |
| 54 | +def detect_period_artifacts_by_envelope( |
| 55 | + recording, |
| 56 | + detect_threshold=5, |
| 57 | + min_duration_ms=50, |
| 58 | + freq_max=20.0, |
| 59 | + seed=None, |
| 60 | + noise_levels=None, |
| 61 | + **noise_levels_kwargs, |
| 62 | +): |
| 63 | + """ |
| 64 | + Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of |
| 65 | + a global envelope of the channels. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + recording : RecordingExtractor |
| 70 | + The recording extractor to detect putative artifacts |
| 71 | + detect_threshold : float, default: 5 |
| 72 | + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` |
| 73 | + freq_max : float, default: 20 |
| 74 | + The maximum frequency for the low pass filter used |
| 75 | + min_duration_ms : float, default: 50 |
| 76 | + The minimum duration for a threshold crossing to be considered as an artefact. |
| 77 | + noise_levels : array |
| 78 | + Noise levels if already computed |
| 79 | + seed : int | None, default: None |
| 80 | + Random seed for `get_noise_levels`. |
| 81 | + If none, `get_noise_levels` uses `seed=0`. |
| 82 | + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function |
| 83 | +
|
| 84 | + """ |
| 85 | + |
| 86 | + envelope = RectifyRecording(recording) |
| 87 | + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) |
| 88 | + envelope = CommonReferenceRecording(envelope) |
| 89 | + |
| 90 | + from spikeinterface.core.node_pipeline import ( |
| 91 | + run_node_pipeline, |
| 92 | + ) |
| 93 | + |
| 94 | + _, job_kwargs = split_job_kwargs(noise_levels_kwargs) |
| 95 | + job_kwargs = fix_job_kwargs(job_kwargs) |
| 96 | + |
| 97 | + node0 = DetectThresholdCrossing( |
| 98 | + recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs |
| 99 | + ) |
| 100 | + |
| 101 | + threshold_crossings = run_node_pipeline( |
| 102 | + recording, |
| 103 | + [node0], |
| 104 | + job_kwargs, |
| 105 | + job_name="detect threshold crossings", |
| 106 | + ) |
| 107 | + |
| 108 | + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) |
| 109 | + threshold_crossings = threshold_crossings[order] |
| 110 | + |
| 111 | + periods = [] |
| 112 | + fs = recording.sampling_frequency |
| 113 | + max_duration_samples = int(min_duration_ms * fs / 1000) |
| 114 | + num_seg = recording.get_num_segments() |
| 115 | + |
| 116 | + for seg_index in range(num_seg): |
| 117 | + sub_periods = [] |
| 118 | + mask = threshold_crossings["segment_index"] == seg_index |
| 119 | + sub_thr = threshold_crossings[mask] |
| 120 | + if len(sub_thr) > 0: |
| 121 | + local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) |
| 122 | + if not sub_thr["front"][0]: |
| 123 | + local_thr["sample_index"] = 0 |
| 124 | + local_thr["front"] = True |
| 125 | + sub_thr = np.hstack((local_thr, sub_thr)) |
| 126 | + if sub_thr["front"][-1]: |
| 127 | + local_thr["sample_index"] = recording.get_num_samples(seg_index) |
| 128 | + local_thr["front"] = False |
| 129 | + sub_thr = np.hstack((sub_thr, local_thr)) |
| 130 | + |
| 131 | + indices = np.flatnonzero(np.diff(sub_thr["front"])) |
| 132 | + for i, j in zip(indices[:-1], indices[1:]): |
| 133 | + if sub_thr["front"][i]: |
| 134 | + start = sub_thr["sample_index"][i] |
| 135 | + end = sub_thr["sample_index"][j] |
| 136 | + if end - start > max_duration_samples: |
| 137 | + sub_periods.append((start, end)) |
| 138 | + |
| 139 | + periods.append(sub_periods) |
| 140 | + |
| 141 | + return periods, envelope |
| 142 | + |
| 143 | + |
| 144 | +class SilencedArtifactsRecording(SilencedPeriodsRecording): |
| 145 | + """ |
| 146 | + Silence user-defined periods from recording extractor traces. The code will construct |
| 147 | + an enveloppe of the recording (as a low pass filtered version of the traces) and detect |
| 148 | + threshold crossings to identify the periods to silence. The periods are then silenced either |
| 149 | + on a per channel basis or across all channels by replacing the values by zeros or by |
| 150 | + adding gaussian noise with the same variance as the one in the recordings |
| 151 | +
|
| 152 | + Parameters |
| 153 | + ---------- |
| 154 | + recording : RecordingExtractor |
| 155 | + The recording extractor to silence putative artifacts |
| 156 | + detect_threshold : float, default: 5 |
| 157 | + The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` |
| 158 | + freq_max : float, default: 20 |
| 159 | + The maximum frequency for the low pass filter used |
| 160 | + min_duration_ms : float, default: 50 |
| 161 | + The minimum duration for a threshold crossing to be considered as an artefact. |
| 162 | + noise_levels : array |
| 163 | + Noise levels if already computed |
| 164 | + seed : int | None, default: None |
| 165 | + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. |
| 166 | + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. |
| 167 | + mode : "zeros" | "noise", default: "zeros" |
| 168 | + Determines what periods are replaced by. Can be one of the following: |
| 169 | +
|
| 170 | + - "zeros": Artifacts are replaced by zeros. |
| 171 | +
|
| 172 | + - "noise": The periods are filled with a gaussion noise that has the |
| 173 | + same variance that the one in the recordings, on a per channel |
| 174 | + basis |
| 175 | + **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function |
| 176 | +
|
| 177 | + Returns |
| 178 | + ------- |
| 179 | + silenced_recording : SilencedArtifactsRecording |
| 180 | + The recording extractor after silencing detected artifacts |
| 181 | + """ |
| 182 | + |
| 183 | + _precomputable_kwarg_names = ["list_periods"] |
| 184 | + |
| 185 | + def __init__( |
| 186 | + self, |
| 187 | + recording, |
| 188 | + detect_threshold=5, |
| 189 | + verbose=False, |
| 190 | + freq_max=20.0, |
| 191 | + min_duration_ms=50, |
| 192 | + mode="zeros", |
| 193 | + noise_levels=None, |
| 194 | + seed=None, |
| 195 | + list_periods=None, |
| 196 | + **noise_levels_kwargs, |
| 197 | + ): |
| 198 | + |
| 199 | + if list_periods is None: |
| 200 | + list_periods, _ = detect_period_artifacts_by_envelope( |
| 201 | + recording, |
| 202 | + detect_threshold=detect_threshold, |
| 203 | + min_duration_ms=min_duration_ms, |
| 204 | + freq_max=freq_max, |
| 205 | + seed=seed, |
| 206 | + noise_levels=noise_levels, |
| 207 | + **noise_levels_kwargs, |
| 208 | + ) |
| 209 | + |
| 210 | + if verbose: |
| 211 | + for i, periods in enumerate(list_periods): |
| 212 | + total_time = np.sum([end - start for start, end in periods]) |
| 213 | + percentage = 100 * total_time / recording.get_num_samples(i) |
| 214 | + print(f"{percentage}% of segment {i} has been flagged as artifactual") |
| 215 | + |
| 216 | + SilencedPeriodsRecording.__init__( |
| 217 | + self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs |
| 218 | + ) |
| 219 | + |
| 220 | + |
| 221 | +# function for API |
| 222 | +silence_artifacts = define_function_handling_dict_from_class( |
| 223 | + source_class=SilencedArtifactsRecording, name="silence_artifacts" |
| 224 | +) |
0 commit comments