@@ -363,6 +363,8 @@ class RasterWidget(BaseRasterWidget):
363363 A sorting object. Deprecated.
364364 sorting_analyzer : SortingAnalyzer | None, default: None
365365 A sorting analyzer object. Deprecated.
366+ sort_by_depth: bool = False
367+ Wether or not to sort units by depth, default: False
366368 """
367369
368370 def __init__ (
@@ -375,6 +377,7 @@ def __init__(
375377 backend : str | None = None ,
376378 sorting : BaseSorting | None = None ,
377379 sorting_analyzer : SortingAnalyzer | None = None ,
380+ sort_by_depth : bool = False ,
378381 ** backend_kwargs ,
379382 ):
380383 if sorting is not None :
@@ -387,13 +390,23 @@ def __init__(
387390 warn (deprecation_msg , category = DeprecationWarning , stacklevel = 2 )
388391 sorting_analyzer_or_sorting = sorting_analyzer
389392
393+
394+
390395 sorting = self .ensure_sorting (sorting_analyzer_or_sorting )
391396
392397 segment_indices = validate_segment_indices (segment_indices , sorting )
393398
394399 if unit_ids is None :
395400 unit_ids = sorting .unit_ids
396401
402+ if sort_by_depth :
403+ # print("hey")
404+ if not sorting_analyzer_or_sorting .has_extension ("unit_locations" ):
405+ raise AttributeError (f"'unit_locations' necessary for sort_by_depth is True" )
406+ depths = sorting_analyzer_or_sorting .get_extension ("unit_locations" ).get_data (outputs = "numpy" )
407+ s_args_depths = np .argsort (depths [:, 1 ])
408+ unit_ids = unit_ids [s_args_depths ]
409+
397410 # Create dict of dicts structure
398411 spike_train_data = {}
399412 y_axis_data = {}
0 commit comments