@@ -83,13 +83,14 @@ def __init__(
8383 groups : list | None = None ,
8484 ref_channel_ids : list | str | int | None = None ,
8585 local_radius : tuple [float , float ] = (30.0 , 55.0 ),
86+ min_local_neighbors : int = 5 ,
8687 dtype : str | np .dtype | None = None ,
8788 ):
8889 num_chans = recording .get_num_channels ()
8990 neighbors = None
9091 # some checks
9192 if reference not in ("global" , "single" , "local" ):
92- raise ValueError ("'reference' must be either 'global', 'single' or 'local'" )
93+ raise ValueError ("'reference' must be either 'global', 'single', 'local'" )
9394 if operator not in ("median" , "average" ):
9495 raise ValueError ("'operator' must be either 'median', 'average'" )
9596
@@ -117,10 +118,10 @@ def __init__(
117118 closest_inds , dist = get_closest_channels (recording )
118119 neighbors = {}
119120 for i in range (num_chans ):
120- mask = (dist [i , :] > local_radius [0 ]) & (dist [i , :] <= local_radius [1 ])
121+ mask = dist [i , :] > local_radius [0 ]
122+ nn = np .cumsum (mask )
123+ mask &= (dist [i , :] <= local_radius [1 ]) | ((0 < nn ) & (nn <= min_local_neighbors ))
121124 neighbors [i ] = closest_inds [i , mask ]
122- assert len (neighbors [i ]) > 0 , "No reference channels available in the local annulus for selection."
123-
124125 dtype_ = fix_dtype (recording , dtype )
125126 BasePreprocessor .__init__ (self , recording , dtype = dtype_ )
126127
@@ -136,7 +137,14 @@ def __init__(
136137
137138 for parent_segment in recording ._recording_segments :
138139 rec_segment = CommonReferenceRecordingSegment (
139- parent_segment , reference , operator , group_indices , ref_channel_indices , local_radius , neighbors , dtype_
140+ parent_segment ,
141+ reference ,
142+ operator ,
143+ group_indices ,
144+ ref_channel_indices ,
145+ local_radius ,
146+ neighbors ,
147+ dtype_ ,
140148 )
141149 self .add_recording_segment (rec_segment )
142150
@@ -147,6 +155,7 @@ def __init__(
147155 operator = operator ,
148156 ref_channel_ids = ref_channel_ids ,
149157 local_radius = local_radius ,
158+ min_local_neighbors = min_local_neighbors ,
150159 dtype = dtype_ .str ,
151160 )
152161
0 commit comments