Skip to content

Commit b2bd036

Browse files
tayheaupre-commit-ci[bot]chrishalcrowalejoe91
authored
Add sort_by_depth kwargs to RasterWidget (#4544)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 1ffd1c8 commit b2bd036

1 file changed

Lines changed: 60 additions & 4 deletions

File tree

src/spikeinterface/widgets/rasters.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __init__(
6363
self,
6464
spike_train_data: dict,
6565
y_axis_data: dict,
66+
depth_dict: dict | None = None,
67+
sort_by_depth: bool = False,
6668
unit_ids: list | None = None,
6769
segment_indices: list | None = None,
6870
durations: list | None = None,
@@ -144,6 +146,8 @@ def __init__(
144146
plot_data = dict(
145147
spike_train_data=concatenated_spike_trains,
146148
y_axis_data=concatenated_y_axis,
149+
depth_dict=depth_dict,
150+
sort_by_depth=sort_by_depth,
147151
unit_ids=unit_ids,
148152
plot_histograms=plot_histograms,
149153
y_lim=y_lim,
@@ -208,6 +212,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
208212
unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate]
209213
unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate]
210214

215+
if dp.sort_by_depth and dp.depth_dict is not None:
216+
ones = np.ones_like(unit_y_data)
217+
unit_y_data = ones * dp.depth_dict[unit_id]
218+
211219
if dp.color_kwargs is None:
212220
scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id])
213221
else:
@@ -249,7 +257,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
249257
x_lim = [0, np.sum(dp.durations)]
250258
scatter_ax.set_xlim(x_lim)
251259

252-
if dp.y_ticks:
260+
if dp.sort_by_depth and dp.depth_dict is not None:
261+
scatter_ax.set_yticks(ticks=list(range(len(dp.depth_dict))), labels=list(dp.depth_dict.keys()))
262+
elif dp.y_ticks:
253263
scatter_ax.set_yticks(**dp.y_ticks)
254264

255265
scatter_ax.set_title(dp.title)
@@ -282,8 +292,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
282292
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
283293
plt.show()
284294

285-
self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys()))
286-
self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1]
295+
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
296+
unit_list = list(data_plot["depth_dict"].keys())
297+
else:
298+
unit_list = list(data_plot["spike_train_data"].keys())
299+
300+
self.unit_selector = UnitSelector(unit_list)
301+
self.unit_selector.value = unit_list[:1]
287302

288303
children = [self.unit_selector]
289304

@@ -294,6 +309,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
294309
)
295310
children.append(self.checkbox_histograms)
296311

312+
if data_plot["depth_dict"] is not None:
313+
self.checkbox_depth = W.Checkbox(
314+
value=data_plot["sort_by_depth"],
315+
description="Sort by depth",
316+
)
317+
children.append(self.checkbox_depth)
318+
297319
left_sidebar = W.VBox(
298320
children=children,
299321
layout=W.Layout(align_items="center", width="100%", height="100%"),
@@ -311,6 +333,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
311333
self.unit_selector.observe(self._update_plot, names="value", type="change")
312334
if data_plot["plot_histograms"] is not None:
313335
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")
336+
if data_plot["depth_dict"] is not None:
337+
self.checkbox_depth.observe(self._update_plot, names="value", type="change")
314338

315339
if backend_kwargs["display"]:
316340
display(self.widget)
@@ -321,6 +345,8 @@ def _full_update_plot(self, change=None):
321345
data_plot["unit_ids"] = self.unit_selector.value
322346
if data_plot["plot_histograms"] is not None:
323347
data_plot["plot_histograms"] = self.checkbox_histograms.value
348+
if data_plot["depth_dict"] is not None:
349+
data_plot["sort_by_depth"] = self.checkbox_depth.value
324350
data_plot["plot_legend"] = False
325351

326352
backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
@@ -334,11 +360,24 @@ def _update_plot(self, change=None):
334360
data_plot["unit_ids"] = self.unit_selector.value
335361
if data_plot["plot_histograms"] is not None:
336362
data_plot["plot_histograms"] = self.checkbox_histograms.value
363+
if data_plot["depth_dict"] is not None:
364+
data_plot["sort_by_depth"] = self.checkbox_depth.value
365+
366+
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
367+
unit_list = list(data_plot["depth_dict"].keys())
368+
else:
369+
unit_list = list(data_plot["spike_train_data"].keys())
370+
371+
old_value = self.unit_selector.value
372+
self.unit_selector.unit_ids = unit_list
373+
self.unit_selector.selector.options = unit_list
374+
self.unit_selector.value = old_value
375+
376+
data_plot["unit_ids"] = self.unit_selector.value
337377
data_plot["plot_legend"] = False
338378

339379
backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
340380
self.plot_matplotlib(data_plot, **backend_kwargs)
341-
342381
self.figure.canvas.draw()
343382
self.figure.canvas.flush_events()
344383

@@ -363,6 +402,8 @@ class RasterWidget(BaseRasterWidget):
363402
A sorting object. Deprecated.
364403
sorting_analyzer : SortingAnalyzer | None, default: None
365404
A sorting analyzer object. Deprecated.
405+
sort_by_depth : bool, default: False
406+
Whether or not to sort units by depth, only available when input is a `SortingAnalyzer`
366407
"""
367408

368409
def __init__(
@@ -375,6 +416,7 @@ def __init__(
375416
backend: str | None = None,
376417
sorting: BaseSorting | None = None,
377418
sorting_analyzer: SortingAnalyzer | None = None,
419+
sort_by_depth: bool = False,
378420
**backend_kwargs,
379421
):
380422
if sorting is not None:
@@ -394,6 +436,18 @@ def __init__(
394436
if unit_ids is None:
395437
unit_ids = sorting.unit_ids
396438

439+
depth_dict = None
440+
if isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):
441+
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
442+
if sort_by_depth:
443+
raise AttributeError(f"'unit_locations' necessary for `sort_by_depth=True`")
444+
else:
445+
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
446+
s_args_depths = np.argsort(depths[:, 1])
447+
depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())}
448+
elif sort_by_depth:
449+
raise AttributeError("`sort_by_depth=True` requires a SortingAnalyzer")
450+
397451
# Create dict of dicts structure
398452
spike_train_data = {}
399453
y_axis_data = {}
@@ -435,6 +489,8 @@ def __init__(
435489
plot_data = dict(
436490
spike_train_data=spike_train_data,
437491
y_axis_data=y_axis_data,
492+
depth_dict=depth_dict,
493+
sort_by_depth=sort_by_depth,
438494
segment_indices=segment_indices,
439495
x_lim=time_range,
440496
y_label="Unit id",

0 commit comments

Comments
 (0)