Skip to content

Commit b1a5665

Browse files
committed
change raster depth sorting logic for widget handling
1 parent 1919bb7 commit b1a5665

1 file changed

Lines changed: 57 additions & 12 deletions

File tree

src/spikeinterface/widgets/rasters.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __init__(
6363
self,
6464
spike_train_data: dict,
6565
y_axis_data: dict,
66+
depth_dict: dict | None = None,
67+
sort_by_depth: bool = False,
6668
unit_ids: list | None = None,
6769
segment_indices: list | None = None,
6870
durations: list | None = None,
@@ -144,6 +146,8 @@ def __init__(
144146
plot_data = dict(
145147
spike_train_data=concatenated_spike_trains,
146148
y_axis_data=concatenated_y_axis,
149+
depth_dict=depth_dict,
150+
sort_by_depth=sort_by_depth,
147151
unit_ids=unit_ids,
148152
plot_histograms=plot_histograms,
149153
y_lim=y_lim,
@@ -200,14 +204,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
200204

201205
spike_train_data = dp.spike_train_data
202206
y_axis_data = dp.y_axis_data
203-
207+
204208
for unit_id in unit_ids:
205209
if unit_id not in spike_train_data:
206210
continue # Skip this unit if not in data
207211

208212
unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate]
209213
unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate]
210214

215+
if dp.sort_by_depth and dp.depth_dict is not None:
216+
ones = np.ones_like(unit_y_data)
217+
unit_y_data = ones * dp.depth_dict[unit_id]
218+
211219
if dp.color_kwargs is None:
212220
scatter_ax.scatter(unit_spike_train, unit_y_data, s=1, label=unit_id, color=unit_colors[unit_id])
213221
else:
@@ -249,7 +257,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
249257
x_lim = [0, np.sum(dp.durations)]
250258
scatter_ax.set_xlim(x_lim)
251259

252-
if dp.y_ticks:
260+
if dp.sort_by_depth and dp.depth_dict is not None:
261+
scatter_ax.set_yticks(
262+
ticks=list(range(len(dp.depth_dict))),
263+
labels=list(dp.depth_dict.keys())
264+
)
265+
elif dp.y_ticks:
253266
scatter_ax.set_yticks(**dp.y_ticks)
254267

255268
scatter_ax.set_title(dp.title)
@@ -282,8 +295,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
282295
self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
283296
plt.show()
284297

285-
self.unit_selector = UnitSelector(list(data_plot["spike_train_data"].keys()))
286-
self.unit_selector.value = list(data_plot["spike_train_data"].keys())[:1]
298+
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
299+
unit_list = list(data_plot["depth_dict"].keys())
300+
else:
301+
unit_list = (list(data_plot["spike_train_data"].keys()))
302+
303+
self.unit_selector = UnitSelector(unit_list)
304+
self.unit_selector.value = unit_list[:1]
287305

288306
children = [self.unit_selector]
289307

@@ -294,6 +312,13 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
294312
)
295313
children.append(self.checkbox_histograms)
296314

315+
if data_plot["depth_dict"] is not None:
316+
self.checkbox_depth = W.Checkbox(
317+
value=data_plot["sort_by_depth"],
318+
description="Sort by depth",
319+
)
320+
children.append(self.checkbox_depth)
321+
297322
left_sidebar = W.VBox(
298323
children=children,
299324
layout=W.Layout(align_items="center", width="100%", height="100%"),
@@ -311,16 +336,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
311336
self.unit_selector.observe(self._update_plot, names="value", type="change")
312337
if data_plot["plot_histograms"] is not None:
313338
self.checkbox_histograms.observe(self._full_update_plot, names="value", type="change")
339+
if data_plot["depth_dict"] is not None:
340+
self.checkbox_depth.observe(self._update_plot, names="value", type="change")
314341

315342
if backend_kwargs["display"]:
316343
display(self.widget)
317-
344+
318345
def _full_update_plot(self, change=None):
319346
self.figure.clear()
320347
data_plot = self.next_data_plot
321348
data_plot["unit_ids"] = self.unit_selector.value
322349
if data_plot["plot_histograms"] is not None:
323350
data_plot["plot_histograms"] = self.checkbox_histograms.value
351+
if data_plot["depth_dict"] is not None:
352+
data_plot["sort_by_depth"] = self.checkbox_depth.value
324353
data_plot["plot_legend"] = False
325354

326355
backend_kwargs = dict(figure=self.figure, axes=None, ax=None)
@@ -334,15 +363,27 @@ def _update_plot(self, change=None):
334363
data_plot["unit_ids"] = self.unit_selector.value
335364
if data_plot["plot_histograms"] is not None:
336365
data_plot["plot_histograms"] = self.checkbox_histograms.value
366+
if data_plot["depth_dict"] is not None:
367+
data_plot["sort_by_depth"] = self.checkbox_depth.value
368+
369+
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
370+
unit_list = list(data_plot["depth_dict"].keys())
371+
else:
372+
unit_list = list(data_plot["spike_train_data"].keys())
373+
374+
old_value = self.unit_selector.value
375+
self.unit_selector.unit_ids = unit_list
376+
self.unit_selector.selector.options = unit_list
377+
self.unit_selector.value = old_value
378+
379+
data_plot["unit_ids"] = self.unit_selector.value
337380
data_plot["plot_legend"] = False
338381

339382
backend_kwargs = dict(figure=None, axes=self.axes, ax=None)
340383
self.plot_matplotlib(data_plot, **backend_kwargs)
341-
342384
self.figure.canvas.draw()
343385
self.figure.canvas.flush_events()
344386

345-
346387
class RasterWidget(BaseRasterWidget):
347388
"""
348389
Plots spike train rasters.
@@ -397,12 +438,14 @@ def __init__(
397438
if unit_ids is None:
398439
unit_ids = sorting.unit_ids
399440

400-
if sort_by_depth:
401-
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
441+
442+
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
443+
if sort_by_depth:
402444
raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True")
403-
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
404-
s_args_depths = np.argsort(depths[:, 1])
405-
unit_ids = unit_ids[s_args_depths]
445+
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
446+
s_args_depths = np.argsort(depths[:, 1])
447+
depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())}
448+
# unit_ids = unit_ids[s_args_depths]
406449

407450
# Create dict of dicts structure
408451
spike_train_data = {}
@@ -445,6 +488,8 @@ def __init__(
445488
plot_data = dict(
446489
spike_train_data=spike_train_data,
447490
y_axis_data=y_axis_data,
491+
depth_dict=depth_dict,
492+
sort_by_depth=sort_by_depth,
448493
segment_indices=segment_indices,
449494
x_lim=time_range,
450495
y_label="Unit id",

0 commit comments

Comments
 (0)