Skip to content

Commit e206e32

Browse files
committed
add sort_by_depth kwargs to RasterWidget
1 parent a3fe49c commit e206e32

1 file changed

Lines changed: 13 additions & 0 deletions

File tree

src/spikeinterface/widgets/rasters.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ class RasterWidget(BaseRasterWidget):
363363
A sorting object. Deprecated.
364364
sorting_analyzer : SortingAnalyzer | None, default: None
365365
A sorting analyzer object. Deprecated.
366+
sort_by_depth: bool = False
367+
Wether or not to sort units by depth, default: False
366368
"""
367369

368370
def __init__(
@@ -375,6 +377,7 @@ def __init__(
375377
backend: str | None = None,
376378
sorting: BaseSorting | None = None,
377379
sorting_analyzer: SortingAnalyzer | None = None,
380+
sort_by_depth : bool = False,
378381
**backend_kwargs,
379382
):
380383
if sorting is not None:
@@ -387,13 +390,23 @@ def __init__(
387390
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
388391
sorting_analyzer_or_sorting = sorting_analyzer
389392

393+
394+
390395
sorting = self.ensure_sorting(sorting_analyzer_or_sorting)
391396

392397
segment_indices = validate_segment_indices(segment_indices, sorting)
393398

394399
if unit_ids is None:
395400
unit_ids = sorting.unit_ids
396401

402+
if sort_by_depth :
403+
# print("hey")
404+
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
405+
raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True")
406+
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
407+
s_args_depths = np.argsort(depths[:, 1])
408+
unit_ids = unit_ids[s_args_depths]
409+
397410
# Create dict of dicts structure
398411
spike_train_data = {}
399412
y_axis_data = {}

0 commit comments

Comments
 (0)