Skip to content

Commit 1ceac4e

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b1a5665 commit 1ceac4e

1 file changed

Lines changed: 7 additions & 10 deletions

File tree

src/spikeinterface/widgets/rasters.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
204204

205205
spike_train_data = dp.spike_train_data
206206
y_axis_data = dp.y_axis_data
207-
207+
208208
for unit_id in unit_ids:
209209
if unit_id not in spike_train_data:
210210
continue # Skip this unit if not in data
@@ -258,10 +258,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
258258
scatter_ax.set_xlim(x_lim)
259259

260260
if dp.sort_by_depth and dp.depth_dict is not None:
261-
scatter_ax.set_yticks(
262-
ticks=list(range(len(dp.depth_dict))),
263-
labels=list(dp.depth_dict.keys())
264-
)
261+
scatter_ax.set_yticks(ticks=list(range(len(dp.depth_dict))), labels=list(dp.depth_dict.keys()))
265262
elif dp.y_ticks:
266263
scatter_ax.set_yticks(**dp.y_ticks)
267264

@@ -298,7 +295,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
298295
if data_plot["sort_by_depth"] and data_plot["depth_dict"] is not None:
299296
unit_list = list(data_plot["depth_dict"].keys())
300297
else:
301-
unit_list = (list(data_plot["spike_train_data"].keys()))
298+
unit_list = list(data_plot["spike_train_data"].keys())
302299

303300
self.unit_selector = UnitSelector(unit_list)
304301
self.unit_selector.value = unit_list[:1]
@@ -341,7 +338,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
341338

342339
if backend_kwargs["display"]:
343340
display(self.widget)
344-
341+
345342
def _full_update_plot(self, change=None):
346343
self.figure.clear()
347344
data_plot = self.next_data_plot
@@ -370,7 +367,7 @@ def _update_plot(self, change=None):
370367
unit_list = list(data_plot["depth_dict"].keys())
371368
else:
372369
unit_list = list(data_plot["spike_train_data"].keys())
373-
370+
374371
old_value = self.unit_selector.value
375372
self.unit_selector.unit_ids = unit_list
376373
self.unit_selector.selector.options = unit_list
@@ -384,6 +381,7 @@ def _update_plot(self, change=None):
384381
self.figure.canvas.draw()
385382
self.figure.canvas.flush_events()
386383

384+
387385
class RasterWidget(BaseRasterWidget):
388386
"""
389387
Plots spike train rasters.
@@ -438,13 +436,12 @@ def __init__(
438436
if unit_ids is None:
439437
unit_ids = sorting.unit_ids
440438

441-
442439
if not sorting_analyzer_or_sorting.has_extension("unit_locations"):
443440
if sort_by_depth:
444441
raise AttributeError(f"'unit_locations' necessary for sort_by_depth is True")
445442
depths = sorting_analyzer_or_sorting.get_extension("unit_locations").get_data(outputs="numpy")
446443
s_args_depths = np.argsort(depths[:, 1])
447-
depth_dict = {b:i for i,b in enumerate(unit_ids[s_args_depths].tolist())}
444+
depth_dict = {b: i for i, b in enumerate(unit_ids[s_args_depths].tolist())}
448445
# unit_ids = unit_ids[s_args_depths]
449446

450447
# Create dict of dicts structure

0 commit comments

Comments
 (0)