@@ -9,6 +9,7 @@ class BaseScatterView(ViewBase):
99 _depend_on = None
1010 _settings = [
1111 {'name' : "auto_decimate" , 'type' : 'bool' , 'value' : True },
12+ {'name' : "display_valid_periods" , 'type' : 'bool' , 'value' : True },
1213 {'name' : 'max_spikes_per_unit' , 'type' : 'int' , 'value' : 5_000 },
1314 {'name' : 'alpha' , 'type' : 'float' , 'value' : 0.7 , 'limits' :(0 , 1. ), 'step' :0.05 },
1415 {'name' : 'scatter_size' , 'type' : 'float' , 'value' : 2. , 'step' :0.5 },
@@ -39,6 +40,8 @@ def __init__(self, spike_data, y_label, controller=None, parent=None, backend="q
3940
4041 ViewBase .__init__ (self , controller = controller , parent = parent , backend = backend )
4142
43+ self .valid_period_regions = []
44+
4245
4346 def get_unit_data (self , unit_id , segment_index = 0 ):
4447 inds = self .controller .get_spike_indices (unit_id , segment_index = segment_index )
@@ -298,8 +301,8 @@ def _qt_refresh(self, set_scatter_range=False):
298301 all_inds = []
299302 ymins = []
300303 ymaxs = []
301- for unit_id in self .controller .get_visible_unit_ids ():
302-
304+ visible_units = self .controller .get_visible_unit_ids ()
305+ for unit_id in visible_units :
303306 spike_times , spike_data , hist_count , hist_bins , ymin , ymax , inds = self .get_unit_data (
304307 unit_id ,
305308 segment_index = segment_index
@@ -336,9 +339,28 @@ def _qt_refresh(self, set_scatter_range=False):
336339 self .viewBox2 .setXRange (0 , self ._max_count , padding = 0.0 )
337340
338341 # explicitly set the y-range of the histogram to match the spike data
339- spike_times , spike_data = self .get_selected_spikes_data (segment_index = self . combo_seg . currentIndex () , visible_inds = all_inds )
342+ spike_times , spike_data = self .get_selected_spikes_data (segment_index = segment_index , visible_inds = all_inds )
340343 self .scatter_select .setData (spike_times , spike_data )
341344
345+ if self .settings ["display_valid_periods" ] and self .controller .valid_periods is not None :
346+ for region in self .valid_period_regions :
347+ self .plot .removeItem (region )
348+ self .valid_period_regions = []
349+ for unit_id in visible_units :
350+ valid_periods_unit = self .controller .valid_periods [segment_index ][unit_id ]
351+ color = self .get_unit_color (unit_id , alpha = 0.2 )
352+ pen_color = pg .mkColor (color )
353+ for period in valid_periods_unit :
354+ t_start = self .controller .sample_index_to_time (period [0 ])
355+ t_end = self .controller .sample_index_to_time (period [1 ])
356+ region = pg .LinearRegionItem ([t_start , t_end ], movable = False , brush = pen_color , pen = pen_color )
357+ self .plot .addItem (region , ignoreBounds = True )
358+ self .valid_period_regions .append (region )
359+ else :
360+ for region in self .valid_period_regions :
361+ self .plot .removeItem (region )
362+ self .valid_period_regions = []
363+
342364 def _qt_on_time_info_updated (self ):
343365 if self .combo_seg .currentIndex () != self .controller .get_time ()[1 ]:
344366 self ._block_auto_refresh_and_notify = True
@@ -452,6 +474,17 @@ def _panel_make_layout(self):
452474 # Add SelectionGeometry event handler to capture lasso vertices
453475 self .scatter_fig .on_event ('selectiongeometry' , self ._on_panel_selection_geometry )
454476
477+ self .valid_periods_source = ColumnDataSource (data = dict (
478+ left = [], right = [], top = [], bottom = [], fill_color = []
479+ ))
480+ self .scatter_fig .quad (
481+ left = "left" , right = "right" ,
482+ top = "top" , bottom = "bottom" ,
483+ fill_color = "fill_color" ,
484+ line_color = None ,
485+ source = self .valid_periods_source ,
486+ )
487+
455488 self .hist_source = ColumnDataSource (data = {"x" : [], "y" : []})
456489 self .hist_data_source = ColumnDataSource (data = dict (x = [], y = [], color = []))
457490 self .hist_fig = bpl .figure (
@@ -496,8 +529,6 @@ def _panel_make_layout(self):
496529 ),
497530 )
498531 )
499- # self.hist_lines = []
500- self .noise_harea = []
501532 self .plotted_inds = []
502533
503534 def _panel_refresh (self , set_scatter_range = False ):
@@ -569,6 +600,24 @@ def _panel_refresh(self, set_scatter_range=False):
569600 # handle selected spikes
570601 self ._panel_update_selected_spikes ()
571602
603+ # Update valid period regions
604+ if self .settings ["display_valid_periods" ] and self .controller .valid_periods is not None :
605+ lefts , rights , tops , bottoms , colors = [], [], [], [], []
606+ for unit_id in visible_unit_ids :
607+ valid_periods_unit = self .controller .valid_periods [segment_index ][unit_id ]
608+ color_shade = self .get_unit_color (unit_id , alpha = 0.2 )
609+ for period in valid_periods_unit :
610+ lefts .append (self .controller .sample_index_to_time (period [0 ]))
611+ rights .append (self .controller .sample_index_to_time (period [1 ]))
612+ tops .append (1_000_000 )
613+ bottoms .append (- 1_000_000 )
614+ colors .append (color_shade )
615+ self .valid_periods_source .data = dict (
616+ left = lefts , right = rights , top = tops , bottom = bottoms , fill_color = colors
617+ )
618+ else :
619+ self .valid_periods_source .data = dict (left = [], right = [], top = [], bottom = [], fill_color = [])
620+
572621 # Defer Range updates to avoid nested document lock issues
573622 # def update_ranges():
574623 if set_scatter_range or not self ._first_refresh_done :
@@ -578,8 +627,6 @@ def _panel_refresh(self, set_scatter_range=False):
578627 self .hist_fig .x_range .end = max_count
579628 self .hist_fig .xaxis .ticker = FixedTicker (ticks = [0 , max_count // 2 , max_count ])
580629
581- # Schedule the update to run after the current event loop iteration
582- # pn.state.execute(update_ranges, schedule=True)
583630
584631 def _panel_on_select_button (self , event ):
585632 import panel as pn
0 commit comments