@@ -63,6 +63,8 @@ def __init__(
6363 self ,
6464 spike_train_data : dict ,
6565 y_axis_data : dict ,
66+ depth_dict : dict | None = None ,
67+ sort_by_depth : bool = False ,
6668 unit_ids : list | None = None ,
6769 segment_indices : list | None = None ,
6870 durations : list | None = None ,
@@ -144,6 +146,8 @@ def __init__(
144146 plot_data = dict (
145147 spike_train_data = concatenated_spike_trains ,
146148 y_axis_data = concatenated_y_axis ,
149+ depth_dict = depth_dict ,
150+ sort_by_depth = sort_by_depth ,
147151 unit_ids = unit_ids ,
148152 plot_histograms = plot_histograms ,
149153 y_lim = y_lim ,
@@ -200,14 +204,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
200204
201205 spike_train_data = dp .spike_train_data
202206 y_axis_data = dp .y_axis_data
203-
207+
204208 for unit_id in unit_ids :
205209 if unit_id not in spike_train_data :
206210 continue # Skip this unit if not in data
207211
208212 unit_spike_train = spike_train_data [unit_id ][:: dp .scatter_decimate ]
209213 unit_y_data = y_axis_data [unit_id ][:: dp .scatter_decimate ]
210214
215+ if dp .sort_by_depth and dp .depth_dict is not None :
216+ ones = np .ones_like (unit_y_data )
217+ unit_y_data = ones * dp .depth_dict [unit_id ]
218+
211219 if dp .color_kwargs is None :
212220 scatter_ax .scatter (unit_spike_train , unit_y_data , s = 1 , label = unit_id , color = unit_colors [unit_id ])
213221 else :
@@ -249,7 +257,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
249257 x_lim = [0 , np .sum (dp .durations )]
250258 scatter_ax .set_xlim (x_lim )
251259
252- if dp .y_ticks :
260+ if dp .sort_by_depth and dp .depth_dict is not None :
261+ scatter_ax .set_yticks (
262+ ticks = list (range (len (dp .depth_dict ))),
263+ labels = list (dp .depth_dict .keys ())
264+ )
265+ elif dp .y_ticks :
253266 scatter_ax .set_yticks (** dp .y_ticks )
254267
255268 scatter_ax .set_title (dp .title )
@@ -282,8 +295,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
282295 self .figure = plt .figure (figsize = ((ratios [1 ] * width_cm ) * cm , height_cm * cm ))
283296 plt .show ()
284297
285- self .unit_selector = UnitSelector (list (data_plot ["spike_train_data" ].keys ()))
286- self .unit_selector .value = list (data_plot ["spike_train_data" ].keys ())[:1 ]
298+ if data_plot ["sort_by_depth" ] and data_plot ["depth_dict" ] is not None :
299+ unit_list = list (data_plot ["depth_dict" ].keys ())
300+ else :
301+ unit_list = (list (data_plot ["spike_train_data" ].keys ()))
302+
303+ self .unit_selector = UnitSelector (unit_list )
304+ self .unit_selector .value = unit_list [:1 ]
287305
288306 children = [self .unit_selector ]
289307
@@ -294,6 +312,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
294312 )
295313 children .append (self .checkbox_histograms )
296314
315+ if data_plot ["depth_dict" ] is not None :
316+ self .checkbox_depth = W .Checkbox (
317+ value = data_plot ["sort_by_depth" ],
318+ description = "Sort by depth" ,
319+ )
320+ children .append (self .checkbox_depth )
321+
297322 left_sidebar = W .VBox (
298323 children = children ,
299324 layout = W .Layout (align_items = "center" , width = "100%" , height = "100%" ),
@@ -311,16 +336,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
311336 self .unit_selector .observe (self ._update_plot , names = "value" , type = "change" )
312337 if data_plot ["plot_histograms" ] is not None :
313338 self .checkbox_histograms .observe (self ._full_update_plot , names = "value" , type = "change" )
339+ if data_plot ["depth_dict" ] is not None :
340+ self .checkbox_depth .observe (self ._update_plot , names = "value" , type = "change" )
314341
315342 if backend_kwargs ["display" ]:
316343 display (self .widget )
317-
344+
318345 def _full_update_plot (self , change = None ):
319346 self .figure .clear ()
320347 data_plot = self .next_data_plot
321348 data_plot ["unit_ids" ] = self .unit_selector .value
322349 if data_plot ["plot_histograms" ] is not None :
323350 data_plot ["plot_histograms" ] = self .checkbox_histograms .value
351+ if data_plot ["depth_dict" ] is not None :
352+ data_plot ["sort_by_depth" ] = self .checkbox_depth .value
324353 data_plot ["plot_legend" ] = False
325354
326355 backend_kwargs = dict (figure = self .figure , axes = None , ax = None )
@@ -334,15 +363,27 @@ def _update_plot(self, change=None):
334363 data_plot ["unit_ids" ] = self .unit_selector .value
335364 if data_plot ["plot_histograms" ] is not None :
336365 data_plot ["plot_histograms" ] = self .checkbox_histograms .value
366+ if data_plot ["depth_dict" ] is not None :
367+ data_plot ["sort_by_depth" ] = self .checkbox_depth .value
368+
369+ if data_plot ["sort_by_depth" ] and data_plot ["depth_dict" ] is not None :
370+ unit_list = list (data_plot ["depth_dict" ].keys ())
371+ else :
372+ unit_list = list (data_plot ["spike_train_data" ].keys ())
373+
374+ old_value = self .unit_selector .value
375+ self .unit_selector .unit_ids = unit_list
376+ self .unit_selector .selector .options = unit_list
377+ self .unit_selector .value = old_value
378+
379+ data_plot ["unit_ids" ] = self .unit_selector .value
337380 data_plot ["plot_legend" ] = False
338381
339382 backend_kwargs = dict (figure = None , axes = self .axes , ax = None )
340383 self .plot_matplotlib (data_plot , ** backend_kwargs )
341-
342384 self .figure .canvas .draw ()
343385 self .figure .canvas .flush_events ()
344386
345-
346387class RasterWidget (BaseRasterWidget ):
347388 """
348389 Plots spike train rasters.
@@ -397,12 +438,14 @@ def __init__(
397438 if unit_ids is None :
398439 unit_ids = sorting .unit_ids
399440
400- if sort_by_depth :
401- if not sorting_analyzer_or_sorting .has_extension ("unit_locations" ):
441+
442+ if not sorting_analyzer_or_sorting .has_extension ("unit_locations" ):
443+ if sort_by_depth :
402444 raise AttributeError (f"'unit_locations' necessary for sort_by_depth is True" )
403- depths = sorting_analyzer_or_sorting .get_extension ("unit_locations" ).get_data (outputs = "numpy" )
404- s_args_depths = np .argsort (depths [:, 1 ])
405- unit_ids = unit_ids [s_args_depths ]
445+ depths = sorting_analyzer_or_sorting .get_extension ("unit_locations" ).get_data (outputs = "numpy" )
446+ s_args_depths = np .argsort (depths [:, 1 ])
447+ depth_dict = {b :i for i ,b in enumerate (unit_ids [s_args_depths ].tolist ())}
448+ # unit_ids = unit_ids[s_args_depths]
406449
407450 # Create dict of dicts structure
408451 spike_train_data = {}
@@ -445,6 +488,8 @@ def __init__(
445488 plot_data = dict (
446489 spike_train_data = spike_train_data ,
447490 y_axis_data = y_axis_data ,
491+ depth_dict = depth_dict ,
492+ sort_by_depth = sort_by_depth ,
448493 segment_indices = segment_indices ,
449494 x_lim = time_range ,
450495 y_label = "Unit id" ,
0 commit comments