Skip to content

Commit d7554f6

Browse files
ygerpre-commit-ci[bot]alejoe91
authored
Remove artefacts based of an envelope (#3715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent db6733d commit d7554f6

4 files changed

Lines changed: 250 additions & 9 deletions

File tree

src/spikeinterface/preprocessing/preprocessing_classes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from .depth_order import DepthOrderRecording, depth_order
5151
from .astype import AstypeRecording, astype
5252
from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed
53+
from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts
5354

5455
_all_preprocesser_dict = {
5556
# filter stuff
@@ -89,6 +90,7 @@
8990
DirectionalDerivativeRecording: directional_derivative,
9091
AstypeRecording: astype,
9192
UnsignedToSignedRecording: unsigned_to_signed,
93+
SilencedArtifactsRecording: silence_artifacts,
9294
}
9395
# we control import in the preprocessing init by setting an __all__
9496

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
6+
from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording
7+
from spikeinterface.preprocessing.rectify import RectifyRecording
8+
from spikeinterface.preprocessing.common_reference import CommonReferenceRecording
9+
from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording
10+
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
11+
from spikeinterface.core.recording_tools import get_noise_levels
12+
from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype
13+
import numpy as np
14+
15+
16+
class DetectThresholdCrossing(PeakDetector):
17+
18+
name = "threshold_crossings"
19+
preferred_mp_context = None
20+
21+
def __init__(
22+
self,
23+
recording,
24+
detect_threshold=5,
25+
noise_levels=None,
26+
seed=None,
27+
noise_levels_kwargs=dict(),
28+
):
29+
PeakDetector.__init__(self, recording, return_output=True)
30+
if noise_levels is None:
31+
random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy()
32+
random_slices_kwargs["seed"] = seed
33+
noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs)
34+
self.abs_thresholds = noise_levels * detect_threshold
35+
self._dtype = np.dtype(base_peak_dtype + [("front", "bool")])
36+
37+
def get_trace_margin(self):
38+
return 0
39+
40+
def get_dtype(self):
41+
return self._dtype
42+
43+
def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
44+
z = np.median(traces / self.abs_thresholds, 1)
45+
threshold_mask = np.diff((z > 1) != 0, axis=0)
46+
indices = np.flatnonzero(threshold_mask)
47+
threshold_crossings = np.zeros(indices.size, dtype=self._dtype)
48+
threshold_crossings["sample_index"] = indices
49+
threshold_crossings["front"][::2] = True
50+
threshold_crossings["front"][1::2] = False
51+
return (threshold_crossings,)
52+
53+
54+
def detect_period_artifacts_by_envelope(
55+
recording,
56+
detect_threshold=5,
57+
min_duration_ms=50,
58+
freq_max=20.0,
59+
seed=None,
60+
noise_levels=None,
61+
**noise_levels_kwargs,
62+
):
63+
"""
64+
Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of
65+
a global envelope of the channels.
66+
67+
Parameters
68+
----------
69+
recording : RecordingExtractor
70+
The recording extractor to detect putative artifacts
71+
detect_threshold : float, default: 5
72+
The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
73+
freq_max : float, default: 20
74+
The maximum frequency for the low pass filter used
75+
min_duration_ms : float, default: 50
76+
The minimum duration for a threshold crossing to be considered as an artefact.
77+
noise_levels : array
78+
Noise levels if already computed
79+
seed : int | None, default: None
80+
Random seed for `get_noise_levels`.
81+
If none, `get_noise_levels` uses `seed=0`.
82+
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
83+
84+
"""
85+
86+
envelope = RectifyRecording(recording)
87+
envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max)
88+
envelope = CommonReferenceRecording(envelope)
89+
90+
from spikeinterface.core.node_pipeline import (
91+
run_node_pipeline,
92+
)
93+
94+
_, job_kwargs = split_job_kwargs(noise_levels_kwargs)
95+
job_kwargs = fix_job_kwargs(job_kwargs)
96+
97+
node0 = DetectThresholdCrossing(
98+
recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs
99+
)
100+
101+
threshold_crossings = run_node_pipeline(
102+
recording,
103+
[node0],
104+
job_kwargs,
105+
job_name="detect threshold crossings",
106+
)
107+
108+
order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"]))
109+
threshold_crossings = threshold_crossings[order]
110+
111+
periods = []
112+
fs = recording.sampling_frequency
113+
max_duration_samples = int(min_duration_ms * fs / 1000)
114+
num_seg = recording.get_num_segments()
115+
116+
for seg_index in range(num_seg):
117+
sub_periods = []
118+
mask = threshold_crossings["segment_index"] == seg_index
119+
sub_thr = threshold_crossings[mask]
120+
if len(sub_thr) > 0:
121+
local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")]))
122+
if not sub_thr["front"][0]:
123+
local_thr["sample_index"] = 0
124+
local_thr["front"] = True
125+
sub_thr = np.hstack((local_thr, sub_thr))
126+
if sub_thr["front"][-1]:
127+
local_thr["sample_index"] = recording.get_num_samples(seg_index)
128+
local_thr["front"] = False
129+
sub_thr = np.hstack((sub_thr, local_thr))
130+
131+
indices = np.flatnonzero(np.diff(sub_thr["front"]))
132+
for i, j in zip(indices[:-1], indices[1:]):
133+
if sub_thr["front"][i]:
134+
start = sub_thr["sample_index"][i]
135+
end = sub_thr["sample_index"][j]
136+
if end - start > max_duration_samples:
137+
sub_periods.append((start, end))
138+
139+
periods.append(sub_periods)
140+
141+
return periods, envelope
142+
143+
144+
class SilencedArtifactsRecording(SilencedPeriodsRecording):
145+
"""
146+
Silence user-defined periods from recording extractor traces. The code will construct
147+
an enveloppe of the recording (as a low pass filtered version of the traces) and detect
148+
threshold crossings to identify the periods to silence. The periods are then silenced either
149+
on a per channel basis or across all channels by replacing the values by zeros or by
150+
adding gaussian noise with the same variance as the one in the recordings
151+
152+
Parameters
153+
----------
154+
recording : RecordingExtractor
155+
The recording extractor to silence putative artifacts
156+
detect_threshold : float, default: 5
157+
The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level`
158+
freq_max : float, default: 20
159+
The maximum frequency for the low pass filter used
160+
min_duration_ms : float, default: 50
161+
The minimum duration for a threshold crossing to be considered as an artefact.
162+
noise_levels : array
163+
Noise levels if already computed
164+
seed : int | None, default: None
165+
Random seed for `get_noise_levels` and `NoiseGeneratorRecording`.
166+
If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`.
167+
mode : "zeros" | "noise", default: "zeros"
168+
Determines what periods are replaced by. Can be one of the following:
169+
170+
- "zeros": Artifacts are replaced by zeros.
171+
172+
- "noise": The periods are filled with a gaussion noise that has the
173+
same variance that the one in the recordings, on a per channel
174+
basis
175+
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function
176+
177+
Returns
178+
-------
179+
silenced_recording : SilencedArtifactsRecording
180+
The recording extractor after silencing detected artifacts
181+
"""
182+
183+
_precomputable_kwarg_names = ["list_periods"]
184+
185+
def __init__(
186+
self,
187+
recording,
188+
detect_threshold=5,
189+
verbose=False,
190+
freq_max=20.0,
191+
min_duration_ms=50,
192+
mode="zeros",
193+
noise_levels=None,
194+
seed=None,
195+
list_periods=None,
196+
**noise_levels_kwargs,
197+
):
198+
199+
if list_periods is None:
200+
list_periods, _ = detect_period_artifacts_by_envelope(
201+
recording,
202+
detect_threshold=detect_threshold,
203+
min_duration_ms=min_duration_ms,
204+
freq_max=freq_max,
205+
seed=seed,
206+
noise_levels=noise_levels,
207+
**noise_levels_kwargs,
208+
)
209+
210+
if verbose:
211+
for i, periods in enumerate(list_periods):
212+
total_time = np.sum([end - start for start, end in periods])
213+
percentage = 100 * total_time / recording.get_num_samples(i)
214+
print(f"{percentage}% of segment {i} has been flagged as artifactual")
215+
216+
SilencedPeriodsRecording.__init__(
217+
self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs
218+
)
219+
220+
221+
# function for API
222+
silence_artifacts = define_function_handling_dict_from_class(
223+
source_class=SilencedArtifactsRecording, name="silence_artifacts"
224+
)

src/spikeinterface/preprocessing/silence_periods.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,10 @@ def __init__(
5656
):
5757
available_modes = ("zeros", "noise")
5858
num_seg = recording.get_num_segments()
59-
6059
if num_seg == 1:
6160
if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2:
62-
# when unique segment accept list instead of of list of list/arrays
61+
# when unique segment accept list instead of list of list/arrays
6362
list_periods = [list_periods]
64-
6563
# some checks
6664
assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}"
6765

@@ -80,10 +78,12 @@ def __init__(
8078

8179
if mode in ["noise"]:
8280
if noise_levels is None:
83-
noise_levels_kwargs = noise_levels_kwargs.copy()
84-
noise_levels_kwargs["return_in_uV"] = False
85-
noise_levels_kwargs["seed"] = seed
86-
noise_levels = get_noise_levels(recording, **noise_levels_kwargs)
81+
random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy()
82+
random_slices_kwargs["seed"] = seed
83+
noise_levels = get_noise_levels(
84+
recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs
85+
)
86+
8787
noise_generator = NoiseGeneratorRecording(
8888
num_channels=recording.get_num_channels(),
8989
sampling_frequency=recording.sampling_frequency,
@@ -121,8 +121,7 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg
121121
def get_traces(self, start_frame, end_frame, channel_indices):
122122
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
123123
traces = traces.copy()
124-
125-
if len(self.periods) > 0:
124+
if self.periods.size > 0:
126125
new_interval = np.array([start_frame, end_frame])
127126
lower_index = np.searchsorted(self.periods[:, 1], new_interval[0])
128127
upper_index = np.searchsorted(self.periods[:, 0], new_interval[1])
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
import numpy as np
4+
5+
from spikeinterface.core import generate_recording
6+
from spikeinterface.preprocessing import silence_artifacts
7+
8+
9+
def test_silence_artifacts():
10+
# one segment only
11+
rec = generate_recording(durations=[10.0, 10])
12+
new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50)
13+
14+
15+
if __name__ == "__main__":
16+
test_silence_artifacts()

0 commit comments

Comments
 (0)