@@ -63,6 +63,8 @@ def __init__(
6363 self ,
6464 spike_train_data : dict ,
6565 y_axis_data : dict ,
66+ depth_dict : dict | None = None ,
67+ sort_by_depth : bool = False ,
6668 unit_ids : list | None = None ,
6769 segment_indices : list | None = None ,
6870 durations : list | None = None ,
@@ -144,6 +146,8 @@ def __init__(
144146 plot_data = dict (
145147 spike_train_data = concatenated_spike_trains ,
146148 y_axis_data = concatenated_y_axis ,
149+ depth_dict = depth_dict ,
150+ sort_by_depth = sort_by_depth ,
147151 unit_ids = unit_ids ,
148152 plot_histograms = plot_histograms ,
149153 y_lim = y_lim ,
@@ -208,6 +212,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
208212 unit_spike_train = spike_train_data [unit_id ][:: dp .scatter_decimate ]
209213 unit_y_data = y_axis_data [unit_id ][:: dp .scatter_decimate ]
210214
215+ if dp .sort_by_depth and dp .depth_dict is not None :
216+ ones = np .ones_like (unit_y_data )
217+ unit_y_data = ones * dp .depth_dict [unit_id ]
218+
211219 if dp .color_kwargs is None :
212220 scatter_ax .scatter (unit_spike_train , unit_y_data , s = 1 , label = unit_id , color = unit_colors [unit_id ])
213221 else :
@@ -249,7 +257,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
249257 x_lim = [0 , np .sum (dp .durations )]
250258 scatter_ax .set_xlim (x_lim )
251259
252- if dp .y_ticks :
260+ if dp .sort_by_depth and dp .depth_dict is not None :
261+ scatter_ax .set_yticks (ticks = list (range (len (dp .depth_dict ))), labels = list (dp .depth_dict .keys ()))
262+ elif dp .y_ticks :
253263 scatter_ax .set_yticks (** dp .y_ticks )
254264
255265 scatter_ax .set_title (dp .title )
@@ -282,8 +292,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
282292 self .figure = plt .figure (figsize = ((ratios [1 ] * width_cm ) * cm , height_cm * cm ))
283293 plt .show ()
284294
285- self .unit_selector = UnitSelector (list (data_plot ["spike_train_data" ].keys ()))
286- self .unit_selector .value = list (data_plot ["spike_train_data" ].keys ())[:1 ]
295+ if data_plot ["sort_by_depth" ] and data_plot ["depth_dict" ] is not None :
296+ unit_list = list (data_plot ["depth_dict" ].keys ())
297+ else :
298+ unit_list = list (data_plot ["spike_train_data" ].keys ())
299+
300+ self .unit_selector = UnitSelector (unit_list )
301+ self .unit_selector .value = unit_list [:1 ]
287302
288303 children = [self .unit_selector ]
289304
@@ -294,6 +309,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
294309 )
295310 children .append (self .checkbox_histograms )
296311
312+ if data_plot ["depth_dict" ] is not None :
313+ self .checkbox_depth = W .Checkbox (
314+ value = data_plot ["sort_by_depth" ],
315+ description = "Sort by depth" ,
316+ )
317+ children .append (self .checkbox_depth )
318+
297319 left_sidebar = W .VBox (
298320 children = children ,
299321 layout = W .Layout (align_items = "center" , width = "100%" , height = "100%" ),
@@ -311,6 +333,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
311333 self .unit_selector .observe (self ._update_plot , names = "value" , type = "change" )
312334 if data_plot ["plot_histograms" ] is not None :
313335 self .checkbox_histograms .observe (self ._full_update_plot , names = "value" , type = "change" )
336+ if data_plot ["depth_dict" ] is not None :
337+ self .checkbox_depth .observe (self ._update_plot , names = "value" , type = "change" )
314338
315339 if backend_kwargs ["display" ]:
316340 display (self .widget )
@@ -321,6 +345,8 @@ def _full_update_plot(self, change=None):
321345 data_plot ["unit_ids" ] = self .unit_selector .value
322346 if data_plot ["plot_histograms" ] is not None :
323347 data_plot ["plot_histograms" ] = self .checkbox_histograms .value
348+ if data_plot ["depth_dict" ] is not None :
349+ data_plot ["sort_by_depth" ] = self .checkbox_depth .value
324350 data_plot ["plot_legend" ] = False
325351
326352 backend_kwargs = dict (figure = self .figure , axes = None , ax = None )
@@ -334,11 +360,24 @@ def _update_plot(self, change=None):
334360 data_plot ["unit_ids" ] = self .unit_selector .value
335361 if data_plot ["plot_histograms" ] is not None :
336362 data_plot ["plot_histograms" ] = self .checkbox_histograms .value
363+ if data_plot ["depth_dict" ] is not None :
364+ data_plot ["sort_by_depth" ] = self .checkbox_depth .value
365+
366+ if data_plot ["sort_by_depth" ] and data_plot ["depth_dict" ] is not None :
367+ unit_list = list (data_plot ["depth_dict" ].keys ())
368+ else :
369+ unit_list = list (data_plot ["spike_train_data" ].keys ())
370+
371+ old_value = self .unit_selector .value
372+ self .unit_selector .unit_ids = unit_list
373+ self .unit_selector .selector .options = unit_list
374+ self .unit_selector .value = old_value
375+
376+ data_plot ["unit_ids" ] = self .unit_selector .value
337377 data_plot ["plot_legend" ] = False
338378
339379 backend_kwargs = dict (figure = None , axes = self .axes , ax = None )
340380 self .plot_matplotlib (data_plot , ** backend_kwargs )
341-
342381 self .figure .canvas .draw ()
343382 self .figure .canvas .flush_events ()
344383
@@ -363,6 +402,8 @@ class RasterWidget(BaseRasterWidget):
363402 A sorting object. Deprecated.
364403 sorting_analyzer : SortingAnalyzer | None, default: None
365404 A sorting analyzer object. Deprecated.
405+ sort_by_depth : bool, default: False
406+ Whether or not to sort units by depth, only available when input is a `SortingAnalyzer`
366407 """
367408
368409 def __init__ (
@@ -375,6 +416,7 @@ def __init__(
375416 backend : str | None = None ,
376417 sorting : BaseSorting | None = None ,
377418 sorting_analyzer : SortingAnalyzer | None = None ,
419+ sort_by_depth : bool = False ,
378420 ** backend_kwargs ,
379421 ):
380422 if sorting is not None :
@@ -394,6 +436,18 @@ def __init__(
394436 if unit_ids is None :
395437 unit_ids = sorting .unit_ids
396438
439+ depth_dict = None
440+ if isinstance (sorting_analyzer_or_sorting , SortingAnalyzer ):
441+ if not sorting_analyzer_or_sorting .has_extension ("unit_locations" ):
442+ if sort_by_depth :
443+ raise AttributeError (f"'unit_locations' necessary for `sort_by_depth=True`" )
444+ else :
445+ depths = sorting_analyzer_or_sorting .get_extension ("unit_locations" ).get_data (outputs = "numpy" )
446+ s_args_depths = np .argsort (depths [:, 1 ])
447+ depth_dict = {b : i for i , b in enumerate (unit_ids [s_args_depths ].tolist ())}
448+ elif sort_by_depth :
449+ raise AttributeError ("`sort_by_depth=True` requires a SortingAnalyzer" )
450+
397451 # Create dict of dicts structure
398452 spike_train_data = {}
399453 y_axis_data = {}
@@ -435,6 +489,8 @@ def __init__(
435489 plot_data = dict (
436490 spike_train_data = spike_train_data ,
437491 y_axis_data = y_axis_data ,
492+ depth_dict = depth_dict ,
493+ sort_by_depth = sort_by_depth ,
438494 segment_indices = segment_indices ,
439495 x_lim = time_range ,
440496 y_label = "Unit id" ,
0 commit comments