@@ -117,14 +117,14 @@ class DriftRasterMapWidget(BaseRasterWidget):
117117 "spike_locations" extension computed.
118118 direction : "x" or "y", default: "y"
119119 The direction to display. "y" is the depth direction.
120- segment_index : int, default: None
121- The segment index to display.
122120 recording : RecordingExtractor | None, default: None
123121 The recording extractor object (only used to get "real" times).
124- segment_index : int, default: 0
125- The segment index to display.
126122 sampling_frequency : float, default: None
127123 The sampling frequency (needed if recording is None).
124+ segment_indices : list of int or None, default: None
125+ The segment index or indices to display. If None and there's only one segment, it's used.
126+ If None and there are multiple segments, you must specify which to use.
127+ If a list of indices is provided, peaks and locations are concatenated across the segments.
128128 depth_lim : tuple or None, default: None
129129 The min and max depth to display, if None (min and max of the recording).
130130 scatter_decimate : int, default: None
@@ -149,25 +149,46 @@ def __init__(
149149 direction : str = "y" ,
150150 recording : BaseRecording | None = None ,
151151 sampling_frequency : float | None = None ,
152- segment_index : int | None = None ,
152+ segment_indices : list [ int ] | None = None ,
153153 depth_lim : tuple [float , float ] | None = None ,
154154 color_amplitude : bool = True ,
155155 scatter_decimate : int | None = None ,
156156 cmap : str = "inferno" ,
157157 color : str = "Gray" ,
158158 clim : tuple [float , float ] | None = None ,
159159 alpha : float = 1 ,
160+ segment_index : int | list [int ] | None = None ,
160161 backend : str | None = None ,
161162 ** backend_kwargs ,
162163 ):
164+ import warnings
165+ from matplotlib .pyplot import colormaps
166+ from matplotlib .colors import Normalize
167+
168+ # Handle deprecation of segment_index parameter
169+ if segment_index is not None :
170+ warnings .warn (
171+ "The 'segment_index' parameter is deprecated and will be removed in a future version. "
172+ "Use 'segment_indices' instead." ,
173+ DeprecationWarning ,
174+ stacklevel = 2 ,
175+ )
176+ if segment_indices is None :
177+ if isinstance (segment_index , int ):
178+ segment_indices = [segment_index ]
179+ else :
180+ segment_indices = segment_index
181+
163182 assert peaks is not None or sorting_analyzer is not None
183+
164184 if peaks is not None :
165185 assert peak_locations is not None
166186 if recording is None :
167187 assert sampling_frequency is not None , "If recording is None, you must provide the sampling frequency"
168188 else :
169189 sampling_frequency = recording .sampling_frequency
170190 peak_amplitudes = peaks ["amplitude" ]
191+
171192 if sorting_analyzer is not None :
172193 if sorting_analyzer .has_recording ():
173194 recording = sorting_analyzer .recording
@@ -190,29 +211,56 @@ def __init__(
190211 else :
191212 peak_amplitudes = None
192213
193- if segment_index is None :
194- assert (
195- len (np .unique (peaks ["segment_index" ])) == 1
196- ), "segment_index must be specified if there are multiple segments"
197- segment_index = 0
198- else :
199- peak_mask = peaks ["segment_index" ] == segment_index
200- peaks = peaks [peak_mask ]
201- peak_locations = peak_locations [peak_mask ]
202- if peak_amplitudes is not None :
203- peak_amplitudes = peak_amplitudes [peak_mask ]
204-
205- from matplotlib .pyplot import colormaps
214+ unique_segments = np .unique (peaks ["segment_index" ])
206215
207- if color_amplitude :
208- amps = peak_amplitudes
216+ if segment_indices is None :
217+ if len (unique_segments ) == 1 :
218+ segment_indices = [int (unique_segments [0 ])]
219+ else :
220+ raise ValueError ("segment_indices must be specified if there are multiple segments" )
221+
222+ if not isinstance (segment_indices , list ):
223+ raise ValueError ("segment_indices must be a list of ints" )
224+
225+ # Validate all segment indices exist in the data
226+ for idx in segment_indices :
227+ if idx not in unique_segments :
228+ raise ValueError (f"segment_index { idx } not found in peaks data" )
229+
230+ # Filter data for the selected segments
231+ # Note: For simplicity, we'll filter all data first, then construct dict of dicts
232+ segment_mask = np .isin (peaks ["segment_index" ], segment_indices )
233+ filtered_peaks = peaks [segment_mask ]
234+ filtered_locations = peak_locations [segment_mask ]
235+ if peak_amplitudes is not None :
236+ filtered_amplitudes = peak_amplitudes [segment_mask ]
237+
238+ # Create dict of dicts structure for the base class
239+ spike_train_data = {}
240+ y_axis_data = {}
241+
242+ # Process each segment separately
243+ for seg_idx in segment_indices :
244+ segment_mask = filtered_peaks ["segment_index" ] == seg_idx
245+ segment_peaks = filtered_peaks [segment_mask ]
246+ segment_locations = filtered_locations [segment_mask ]
247+
248+ # Convert peak times to seconds
249+ spike_times = segment_peaks ["sample_index" ] / sampling_frequency
250+
251+ # Store in dict of dicts format (using 0 as the "unit" id)
252+ spike_train_data [seg_idx ] = {0 : spike_times }
253+ y_axis_data [seg_idx ] = {0 : segment_locations [direction ]}
254+
255+ if color_amplitude and peak_amplitudes is not None :
256+ amps = filtered_amplitudes
209257 amps_abs = np .abs (amps )
210258 q_95 = np .quantile (amps_abs , 0.95 )
211- cmap = colormaps [cmap ]
259+ cmap_obj = colormaps [cmap ]
212260 if clim is None :
213261 amps = amps_abs
214262 amps /= q_95
215- c = cmap (amps )
263+ c = cmap_obj (amps )
216264 else :
217265 from matplotlib .colors import Normalize
218266
@@ -226,18 +274,30 @@ def __init__(
226274 else :
227275 color_kwargs = dict (color = color , c = None , alpha = alpha )
228276
229- # convert data into format that `BaseRasterWidget` can take it in
230- spike_train_data = {0 : peaks ["sample_index" ] / sampling_frequency }
231- y_axis_data = {0 : peak_locations [direction ]}
277+ # Calculate segment durations for x-axis limits
278+ if recording is not None :
279+ durations = [recording .get_duration (seg_idx ) for seg_idx in segment_indices ]
280+ else :
281+ # Find boundaries between segments using searchsorted
282+ segment_boundaries = [
283+ np .searchsorted (filtered_peaks ["segment_index" ], [seg_idx , seg_idx + 1 ]) for seg_idx in segment_indices
284+ ]
285+
286+ # Calculate durations from max sample in each segment
287+ durations = [
288+ (filtered_peaks ["sample_index" ][end - 1 ] + 1 ) / sampling_frequency for (_ , end ) in segment_boundaries
289+ ]
232290
233291 plot_data = dict (
234292 spike_train_data = spike_train_data ,
235293 y_axis_data = y_axis_data ,
294+ segment_indices = segment_indices ,
236295 y_lim = depth_lim ,
237296 color_kwargs = color_kwargs ,
238297 scatter_decimate = scatter_decimate ,
239298 title = "Peak depth" ,
240299 y_label = "Depth [um]" ,
300+ durations = durations ,
241301 )
242302
243303 BaseRasterWidget .__init__ (self , ** plot_data , backend = backend , ** backend_kwargs )
@@ -370,10 +430,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
370430 dp .recording ,
371431 )
372432
373- commpon_drift_map_kwargs = dict (
433+ common_drift_map_kwargs = dict (
374434 direction = dp .motion .direction ,
375435 recording = dp .recording ,
376- segment_index = dp .segment_index ,
436+ segment_indices = [ dp .segment_index ] ,
377437 depth_lim = dp .depth_lim ,
378438 scatter_decimate = dp .scatter_decimate ,
379439 color_amplitude = dp .color_amplitude ,
@@ -390,15 +450,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
390450 dp .peak_locations ,
391451 ax = ax0 ,
392452 immediate_plot = True ,
393- ** commpon_drift_map_kwargs ,
453+ ** common_drift_map_kwargs ,
394454 )
395455
396456 _ = DriftRasterMapWidget (
397457 dp .peaks ,
398458 corrected_location ,
399459 ax = ax1 ,
400460 immediate_plot = True ,
401- ** commpon_drift_map_kwargs ,
461+ ** common_drift_map_kwargs ,
402462 )
403463
404464 ax2 .plot (temporal_bins_s , displacement , alpha = 0.2 , color = "black" )
0 commit comments