11import numpy as np
2+ import scipy .signal
23
34from spikeinterface .core .core_tools import define_function_handling_dict_from_class
45from .basepreprocessor import BasePreprocessor , BasePreprocessorSegment
@@ -48,14 +49,15 @@ def __init__(
4849 self ,
4950 recording ,
5051 periods = None ,
51- # this is keep for backward compatibility
52+ # this is kept for backward compatibility
5253 list_periods = None ,
5354 mode = "zeros" ,
5455 noise_levels = None ,
56+ apodization = 7 ,
5557 seed = None ,
5658 ** noise_levels_kwargs ,
5759 ):
58- available_modes = ("zeros" , "noise" )
60+ available_modes = ("zeros" , "noise" , "apodization" )
5961 num_seg = recording .get_num_segments ()
6062
6163 # handle backward compatibility with previous version
@@ -108,11 +110,11 @@ def __init__(
108110 i1 = seg_limits [seg_index + 1 ]
109111 periods_in_seg = periods [i0 :i1 ]
110112 rec_segment = SilencedPeriodsRecordingSegment (
111- parent_segment , periods_in_seg , mode , noise_generator , seg_index
113+ parent_segment , periods_in_seg , mode , noise_generator , seg_index , apodization = apodization ,
112114 )
113115 self .add_recording_segment (rec_segment )
114116
115- self ._kwargs = dict (recording = recording , periods = periods , mode = mode , seed = seed , noise_levels = noise_levels )
117+ self ._kwargs = dict (recording = recording , periods = periods , mode = mode , seed = seed , noise_levels = noise_levels , apodization = apodization )
116118
117119
118120def _all_period_list_to_periods_vec (list_periods , num_seg ):
@@ -154,12 +156,13 @@ def _check_periods(periods, num_seg):
154156
155157
156158class SilencedPeriodsRecordingSegment (BasePreprocessorSegment ):
157- def __init__ (self , parent_recording_segment , periods , mode , noise_generator , seg_index ):
159+ def __init__ (self , parent_recording_segment , periods , mode , noise_generator , seg_index , apodization = 7 ):
158160 BasePreprocessorSegment .__init__ (self , parent_recording_segment )
159161 self .periods = periods
160162 self .mode = mode
161163 self .seg_index = seg_index
162164 self .noise_generator = noise_generator
165+ self .apodization = apodization
163166
164167 def get_traces (self , start_frame , end_frame , channel_indices ):
165168 traces = self .parent_recording_segment .get_traces (start_frame , end_frame , channel_indices )
@@ -185,7 +188,13 @@ def get_traces(self, start_frame, end_frame, channel_indices):
185188 :, channel_indices
186189 ]
187190 traces [onset :offset , :] = noise [onset :offset ]
188-
191+ elif self .mode == "apodization" :
192+ # apply a cosine taper to the saturation to create a mute function
193+ mute = np .zeros (traces .shape [0 ], dtype = np .float32 )
194+ mute [onset :offset ] = 1
195+ win = scipy .signal .windows .cosine (self .apodization )
196+ mute = np .maximum (0 , 1 - scipy .signal .convolve (mute , win , mode = "same" ))
197+ traces = (traces .astype (np .float32 ) * mute [:, np .newaxis ]).astype (traces .dtype )
189198 return traces
190199
191200
0 commit comments