@@ -204,7 +204,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
204204
205205 spike_train_data = dp .spike_train_data
206206 y_axis_data = dp .y_axis_data
207-
207+
208208 for unit_id in unit_ids :
209209 if unit_id not in spike_train_data :
210210 continue # Skip this unit if not in data
@@ -258,10 +258,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
258258 scatter_ax .set_xlim (x_lim )
259259
260260 if dp .sort_by_depth and dp .depth_dict is not None :
261- scatter_ax .set_yticks (
262- ticks = list (range (len (dp .depth_dict ))),
263- labels = list (dp .depth_dict .keys ())
264- )
261+ scatter_ax .set_yticks (ticks = list (range (len (dp .depth_dict ))), labels = list (dp .depth_dict .keys ()))
265262 elif dp .y_ticks :
266263 scatter_ax .set_yticks (** dp .y_ticks )
267264
@@ -298,7 +295,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
298295 if data_plot ["sort_by_depth" ] and data_plot ["depth_dict" ] is not None :
299296 unit_list = list (data_plot ["depth_dict" ].keys ())
300297 else :
301- unit_list = ( list (data_plot ["spike_train_data" ].keys () ))
298+ unit_list = list (data_plot ["spike_train_data" ].keys ())
302299
303300 self .unit_selector = UnitSelector (unit_list )
304301 self .unit_selector .value = unit_list [:1 ]
@@ -341,7 +338,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
341338
342339 if backend_kwargs ["display" ]:
343340 display (self .widget )
344-
341+
345342 def _full_update_plot (self , change = None ):
346343 self .figure .clear ()
347344 data_plot = self .next_data_plot
@@ -370,7 +367,7 @@ def _update_plot(self, change=None):
370367 unit_list = list (data_plot ["depth_dict" ].keys ())
371368 else :
372369 unit_list = list (data_plot ["spike_train_data" ].keys ())
373-
370+
374371 old_value = self .unit_selector .value
375372 self .unit_selector .unit_ids = unit_list
376373 self .unit_selector .selector .options = unit_list
@@ -384,6 +381,7 @@ def _update_plot(self, change=None):
384381 self .figure .canvas .draw ()
385382 self .figure .canvas .flush_events ()
386383
384+
387385class RasterWidget (BaseRasterWidget ):
388386 """
389387 Plots spike train rasters.
@@ -438,13 +436,12 @@ def __init__(
438436 if unit_ids is None :
439437 unit_ids = sorting .unit_ids
440438
441-
442439 if not sorting_analyzer_or_sorting .has_extension ("unit_locations" ):
443440 if sort_by_depth :
444441 raise AttributeError (f"'unit_locations' necessary for sort_by_depth is True" )
445442 depths = sorting_analyzer_or_sorting .get_extension ("unit_locations" ).get_data (outputs = "numpy" )
446443 s_args_depths = np .argsort (depths [:, 1 ])
447- depth_dict = {b :i for i ,b in enumerate (unit_ids [s_args_depths ].tolist ())}
444+ depth_dict = {b : i for i , b in enumerate (unit_ids [s_args_depths ].tolist ())}
448445 # unit_ids = unit_ids[s_args_depths]
449446
450447 # Create dict of dicts structure
0 commit comments