Skip to content

Commit 58d4eb7

Browse files
authored
Add valid periods regions in scatter plots (#254)
1 parent 2198323 commit 58d4eb7

3 files changed

Lines changed: 69 additions & 10 deletions

File tree

spikeinterface_gui/basescatterview.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class BaseScatterView(ViewBase):
99
_depend_on = None
1010
_settings = [
1111
{'name': "auto_decimate", 'type': 'bool', 'value' : True},
12+
{'name': "display_valid_periods", 'type': 'bool', 'value' : True},
1213
{'name': 'max_spikes_per_unit', 'type': 'int', 'value' : 5_000},
1314
{'name': 'alpha', 'type': 'float', 'value' : 0.7, 'limits':(0, 1.), 'step':0.05},
1415
{'name': 'scatter_size', 'type': 'float', 'value' : 2., 'step':0.5},
@@ -39,6 +40,8 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
3940

4041
ViewBase.__init__(self, controller=controller, parent=parent, backend=backend)
4142

43+
self.valid_period_regions = []
44+
4245

4346
def get_unit_data(self, unit_id, segment_index=0):
4447
inds = self.controller.get_spike_indices(unit_id, segment_index=segment_index)
@@ -298,8 +301,8 @@ def _qt_refresh(self, set_scatter_range=False):
298301
all_inds = []
299302
ymins = []
300303
ymaxs = []
301-
for unit_id in self.controller.get_visible_unit_ids():
302-
304+
visible_units = self.controller.get_visible_unit_ids()
305+
for unit_id in visible_units:
303306
spike_times, spike_data, hist_count, hist_bins, ymin, ymax, inds = self.get_unit_data(
304307
unit_id,
305308
segment_index=segment_index
@@ -336,9 +339,28 @@ def _qt_refresh(self, set_scatter_range=False):
336339
self.viewBox2.setXRange(0, self._max_count, padding = 0.0)
337340

338341
# explicitly set the y-range of the histogram to match the spike data
339-
spike_times, spike_data = self.get_selected_spikes_data(segment_index=self.combo_seg.currentIndex(), visible_inds=all_inds)
342+
spike_times, spike_data = self.get_selected_spikes_data(segment_index=segment_index, visible_inds=all_inds)
340343
self.scatter_select.setData(spike_times, spike_data)
341344

345+
if self.settings["display_valid_periods"] and self.controller.valid_periods is not None:
346+
for region in self.valid_period_regions:
347+
self.plot.removeItem(region)
348+
self.valid_period_regions = []
349+
for unit_id in visible_units:
350+
valid_periods_unit = self.controller.valid_periods[segment_index][unit_id]
351+
color = self.get_unit_color(unit_id, alpha=0.2)
352+
pen_color = pg.mkColor(color)
353+
for period in valid_periods_unit:
354+
t_start = self.controller.sample_index_to_time(period[0])
355+
t_end = self.controller.sample_index_to_time(period[1])
356+
region = pg.LinearRegionItem([t_start, t_end], movable=False, brush=pen_color, pen=pen_color)
357+
self.plot.addItem(region, ignoreBounds=True)
358+
self.valid_period_regions.append(region)
359+
else:
360+
for region in self.valid_period_regions:
361+
self.plot.removeItem(region)
362+
self.valid_period_regions = []
363+
342364
def _qt_on_time_info_updated(self):
343365
if self.combo_seg.currentIndex() != self.controller.get_time()[1]:
344366
self._block_auto_refresh_and_notify = True
@@ -452,6 +474,17 @@ def _panel_make_layout(self):
452474
# Add SelectionGeometry event handler to capture lasso vertices
453475
self.scatter_fig.on_event('selectiongeometry', self._on_panel_selection_geometry)
454476

477+
self.valid_periods_source = ColumnDataSource(data=dict(
478+
left=[], right=[], top=[], bottom=[], fill_color=[]
479+
))
480+
self.scatter_fig.quad(
481+
left="left", right="right",
482+
top="top", bottom="bottom",
483+
fill_color="fill_color",
484+
line_color=None,
485+
source=self.valid_periods_source,
486+
)
487+
455488
self.hist_source = ColumnDataSource(data={"x": [], "y": []})
456489
self.hist_data_source = ColumnDataSource(data=dict(x=[], y=[], color=[]))
457490
self.hist_fig = bpl.figure(
@@ -496,8 +529,6 @@ def _panel_make_layout(self):
496529
),
497530
)
498531
)
499-
# self.hist_lines = []
500-
self.noise_harea = []
501532
self.plotted_inds = []
502533

503534
def _panel_refresh(self, set_scatter_range=False):
@@ -569,6 +600,24 @@ def _panel_refresh(self, set_scatter_range=False):
569600
# handle selected spikes
570601
self._panel_update_selected_spikes()
571602

603+
# Update valid period regions
604+
if self.settings["display_valid_periods"] and self.controller.valid_periods is not None:
605+
lefts, rights, tops, bottoms, colors = [], [], [], [], []
606+
for unit_id in visible_unit_ids:
607+
valid_periods_unit = self.controller.valid_periods[segment_index][unit_id]
608+
color_shade = self.get_unit_color(unit_id, alpha=0.2)
609+
for period in valid_periods_unit:
610+
lefts.append(self.controller.sample_index_to_time(period[0]))
611+
rights.append(self.controller.sample_index_to_time(period[1]))
612+
tops.append(1_000_000)
613+
bottoms.append(-1_000_000)
614+
colors.append(color_shade)
615+
self.valid_periods_source.data = dict(
616+
left=lefts, right=rights, top=tops, bottom=bottoms, fill_color=colors
617+
)
618+
else:
619+
self.valid_periods_source.data = dict(left=[], right=[], top=[], bottom=[], fill_color=[])
620+
572621
# Defer Range updates to avoid nested document lock issues
573622
# def update_ranges():
574623
if set_scatter_range or not self._first_refresh_done:
@@ -578,8 +627,6 @@ def _panel_refresh(self, set_scatter_range=False):
578627
self.hist_fig.x_range.end = max_count
579628
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])
580629

581-
# Schedule the update to run after the current event loop iteration
582-
# pn.state.execute(update_ranges, schedule=True)
583630

584631
def _panel_on_select_button(self, event):
585632
import panel as pn

spikeinterface_gui/controller.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ def __init__(
249249
pc_ext = analyzer.get_extension('principal_components')
250250
self.pc_ext = pc_ext
251251

252+
if analyzer.has_extension("valid_unit_periods"):
253+
valid_periods_ext = analyzer.get_extension("valid_unit_periods")
254+
self.valid_periods = valid_periods_ext.get_data(outputs="by_unit")
255+
else:
256+
self.valid_periods = None
257+
252258
self._potential_merges = None
253259

254260
t1 = time.perf_counter()

spikeinterface_gui/view_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import time
22
from contextlib import contextmanager
33

4+
import numpy as np
5+
46
class ViewBase:
57
id: str = None
68
_supported_backend = []
@@ -159,7 +161,7 @@ def continue_from_user(self, warning_msg, action, *args):
159161
# Panel: asynchronous approach with callback
160162
self._panel_insert_warning_with_choice(warning_msg, action, *args)
161163

162-
def get_unit_color(self, unit_id):
164+
def get_unit_color(self, unit_id, alpha=1.0):
163165
if self.backend == "qt":
164166
from .myqt import QT
165167

@@ -170,14 +172,18 @@ def get_unit_color(self, unit_id):
170172
color = self.controller.get_unit_color(unit_id)
171173
r, g, b, a = color
172174
qcolor = QT.QColor(int(r * 255), int(g * 255), int(b * 255))
175+
# only cache
173176
self.controller._cached_qcolors[unit_id] = qcolor
174-
175-
return self.controller._cached_qcolors[unit_id]
177+
else:
178+
qcolor = self.controller._cached_qcolors[unit_id]
179+
qcolor.setAlpha(int(alpha * 255))
180+
return qcolor
176181

177182
elif self.backend == "panel":
178183
import matplotlib
179184

180185
color = self.controller.get_unit_color(unit_id)
186+
color = color[:3] + (np.float64(alpha),)
181187
html_color = matplotlib.colors.rgb2hex(color, keep_alpha=True)
182188
return html_color
183189

0 commit comments

Comments
 (0)