Skip to content

Commit 33d3fec

Browse files
committed
handle BaseSorting error case
1 parent 1ceac4e commit 33d3fec

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

src/spikeinterface/widgets/rasters.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -436,13 +436,16 @@ def __init__(
436436
if unit_ids is None:
437437
unit_ids = sorting.unit_ids
438438

439-
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
440-
if sort_by_depth:
441-
raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True")
442-
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
443-
s_args_depths = np.argsort(depths[:, 1])
444-
depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())}
445-
# unit_ids = unit_ids[s_args_depths]
439+
if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
440+
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
441+
if sort_by_depth:
442+
raise AttributeError(f"'unit_locations' necessary for `sort_by_depth=True`")
443+
else:
444+
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
445+
s_args_depths = np.argsort(depths[:, 1])
446+
depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())}
447+
elif sort_by_depth:
448+
raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer")
446449

447450
# Create dict of dicts structure
448451
spike_train_data = {}

0 commit comments

Comments
 (0)