Skip to content

Commit 9c25668

Browse files
committed
saturation application with apodization
1 parent 99afff6 commit 9c25668

2 files changed

Lines changed: 18 additions & 7 deletions

File tree

src/spikeinterface/preprocessing/detect_artifacts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,7 @@ def __init__(
717717
job_kwargs: dict | None = None,
718718
mode: Literal["zeros", "noise"] = "zeros",
719719
noise_levels_kwargs: dict | None = None,
720+
apodization: int = 7,
720721
seed: int | None = None,
721722
artifact_periods=None,
722723
) -> None:
@@ -729,7 +730,7 @@ def __init__(
729730
recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs
730731
)
731732
super().__init__(
732-
recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed
733+
recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization
733734
)
734735

735736
self._kwargs = dict(
@@ -742,6 +743,7 @@ def __init__(
742743
noise_levels_kwargs=noise_levels_kwargs,
743744
seed=seed,
744745
artifact_periods=artifact_periods,
746+
apodization=apodization,
745747
)
746748

747749

src/spikeinterface/preprocessing/silence_periods.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.signal
23

34
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
45
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
@@ -48,14 +49,15 @@ def __init__(
4849
self,
4950
recording,
5051
periods=None,
51-
# this is keep for backward compatibility
52+
# this is kept for backward compatibility
5253
list_periods=None,
5354
mode="zeros",
5455
noise_levels=None,
56+
apodization=7,
5557
seed=None,
5658
**noise_levels_kwargs,
5759
):
58-
available_modes = ("zeros", "noise")
60+
available_modes = ("zeros", "noise", "apodization")
5961
num_seg = recording.get_num_segments()
6062

6163
# handle backward compatibility with previous version
@@ -108,11 +110,11 @@ def __init__(
108110
i1 = seg_limits[seg_index + 1]
109111
periods_in_seg = periods[i0:i1]
110112
rec_segment = SilencedPeriodsRecordingSegment(
111-
parent_segment, periods_in_seg, mode, noise_generator, seg_index
113+
parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization,
112114
)
113115
self.add_recording_segment(rec_segment)
114116

115-
self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels)
117+
self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization)
116118

117119

118120
def _all_period_list_to_periods_vec(list_periods, num_seg):
@@ -154,12 +156,13 @@ def _check_periods(periods, num_seg):
154156

155157

156158
class SilencedPeriodsRecordingSegment(BasePreprocessorSegment):
157-
def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index):
159+
def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization=7):
158160
BasePreprocessorSegment.__init__(self, parent_recording_segment)
159161
self.periods = periods
160162
self.mode = mode
161163
self.seg_index = seg_index
162164
self.noise_generator = noise_generator
165+
self.apodization = apodization
163166

164167
def get_traces(self, start_frame, end_frame, channel_indices):
165168
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
@@ -185,7 +188,13 @@ def get_traces(self, start_frame, end_frame, channel_indices):
185188
:, channel_indices
186189
]
187190
traces[onset:offset, :] = noise[onset:offset]
188-
191+
elif self.mode == "apodization":
192+
# apply a cosine taper to the saturation to create a mute function
193+
mute = np.zeros(traces.shape[0], dtype=np.float32)
194+
mute[onset:offset] = 1
195+
win = scipy.signal.windows.cosine(self.apodization)
196+
mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same"))
197+
traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype)
189198
return traces
190199

191200

0 commit comments

Comments
 (0)