Skip to content

Commit 540db00

Browse files
committed
minor fixes
1 parent cbc790c commit 540db00

3 files changed

Lines changed: 10 additions & 8 deletions

File tree

src/spikeinterface/widgets/amplitudes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
# For SortingView, ensure we're only using a single segment
8484
if is_sortingview and len(segment_indices) > 1:
8585
warn("SortingView backend currently supports only single segment. Using first segment.")
86-
segment_indices = segment_indices[0]
86+
segment_indices = [segment_indices[0]]
8787

8888
# Create multi-segment data structure (dict of dicts)
8989
spiketrains_by_segment = {}
@@ -150,11 +150,13 @@ def __init__(
150150
first_segment = segment_indices[0]
151151
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
152152
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
153+
print(plot_data["spike_train_data"])
154+
print(plot_data["y_axis_data"])
153155
else:
154156
# Otherwise use the full dict of dicts structure with all segments
155157
plot_data["spike_train_data"] = spiketrains_by_segment
156158
plot_data["y_axis_data"] = amplitudes_by_segment
157-
plot_data["segment_index"] = segment_indices
159+
plot_data["segment_indices"] = segment_indices
158160

159161
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
160162

src/spikeinterface/widgets/motion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ def __init__(
201201
if len(unique_segments) == 1:
202202
segment_indices = [int(unique_segments[0])]
203203
else:
204-
raise ValueError("segment_index must be specified if there are multiple segments")
204+
raise ValueError("segment_indices must be specified if there are multiple segments")
205205

206206
if not isinstance(segment_indices, list):
207-
raise ValueError("segment_index must be an int or a list of ints")
207+
raise ValueError("segment_indices must be a list of ints")
208208

209209
# Validate all segment indices exist in the data
210210
for idx in segment_indices:
@@ -275,7 +275,7 @@ def __init__(
275275
plot_data = dict(
276276
spike_train_data=spike_train_data,
277277
y_axis_data=y_axis_data,
278-
segment_index=segment_indices,
278+
segment_indices=segment_indices,
279279
y_lim=depth_lim,
280280
color_kwargs=color_kwargs,
281281
scatter_decimate=scatter_decimate,
@@ -417,7 +417,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
417417
commpon_drift_map_kwargs = dict(
418418
direction=dp.motion.direction,
419419
recording=dp.recording,
420-
segment_index=dp.segment_index,
420+
segment_indices=list(dp.segment_index),
421421
depth_lim=dp.depth_lim,
422422
scatter_decimate=dp.scatter_decimate,
423423
color_amplitude=dp.color_amplitude,

src/spikeinterface/widgets/rasters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def __init__(
419419

420420
sorting = self.ensure_sorting(sorting)
421421

422-
segment_indices = validate_segment_indices(sorting, segment_indices)
422+
segment_indices = validate_segment_indices(segment_indices, sorting)
423423

424424
if unit_ids is None:
425425
unit_ids = sorting.unit_ids
@@ -479,7 +479,7 @@ def __init__(
479479
plot_data = dict(
480480
spike_train_data=spike_train_data,
481481
y_axis_data=y_axis_data,
482-
segment_index=segment_indices,
482+
segment_indices=segment_indices,
483483
x_lim=time_range,
484484
y_label="Unit id",
485485
unit_ids=unit_ids,

0 commit comments

Comments
 (0)