Skip to content

Commit cbc790c

Browse files
committed
Improve segment validation to list only. Add unitls function for validation
1 parent 468e564 commit cbc790c

4 files changed

Lines changed: 67 additions & 76 deletions

File tree

src/spikeinterface/widgets/amplitudes.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .rasters import BaseRasterWidget
77
from .base import BaseWidget, to_attr
8-
from .utils import get_some_colors
8+
from .utils import get_some_colors, validate_segment_indices
99

1010
from spikeinterface.core.sortinganalyzer import SortingAnalyzer
1111

@@ -25,7 +25,7 @@ class AmplitudesWidget(BaseRasterWidget):
2525
unit_colors : dict | None, default: None
2626
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
2727
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
28-
segment_index : int or list of int or None, default: None
28+
segment_indices : list of int or None, default: None
2929
Segment index or indices to plot. If None and there are multiple segments, defaults to 0.
3030
If list, spike trains and amplitudes are concatenated across the specified segments.
3131
max_spikes_per_unit : int or None, default: None
@@ -52,7 +52,7 @@ def __init__(
5252
sorting_analyzer: SortingAnalyzer,
5353
unit_ids=None,
5454
unit_colors=None,
55-
segment_index=None,
55+
segment_indices=None,
5656
max_spikes_per_unit=None,
5757
y_lim=None,
5858
scatter_decimate=1,
@@ -74,38 +74,16 @@ def __init__(
7474
if unit_ids is None:
7575
unit_ids = sorting.unit_ids
7676

77-
num_segments = sorting.get_num_segments()
78-
7977
# Handle segment_index input
80-
if num_segments > 1:
81-
if segment_index is None:
82-
warn("More than one segment available! Using `segment_index = 0`.")
83-
segment_index = 0
84-
else:
85-
segment_index = 0
78+
segment_indices = validate_segment_indices(segment_indices, sorting)
8679

8780
# Check for SortingView backend
8881
is_sortingview = backend == "sortingview"
8982

9083
# For SortingView, ensure we're only using a single segment
91-
if is_sortingview and isinstance(segment_index, list) and len(segment_index) > 1:
84+
if is_sortingview and len(segment_indices) > 1:
9285
warn("SortingView backend currently supports only single segment. Using first segment.")
93-
segment_index = segment_index[0]
94-
95-
# Convert segment_index to list for consistent processing
96-
if isinstance(segment_index, int):
97-
segment_indices = [segment_index]
98-
elif isinstance(segment_index, list):
99-
segment_indices = segment_index
100-
else:
101-
raise ValueError("segment_index must be an int or a list of ints")
102-
103-
# Validate segment indices
104-
for idx in segment_indices:
105-
if not isinstance(idx, int):
106-
raise ValueError(f"Each segment index must be an integer, got {type(idx)}")
107-
if idx < 0 or idx >= num_segments:
108-
raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})")
86+
segment_indices = segment_indices[0]
10987

11088
# Create multi-segment data structure (dict of dicts)
11189
spiketrains_by_segment = {}

src/spikeinterface/widgets/motion.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class DriftRasterMapWidget(BaseRasterWidget):
121121
The recording extractor object (only used to get "real" times).
122122
sampling_frequency : float, default: None
123123
The sampling frequency (needed if recording is None).
124-
segment_index : int or list of int or None, default: None
124+
segment_indices : list of int or None, default: None
125125
The segment index or indices to display. If None and there's only one segment, it's used.
126126
If None and there are multiple segments, you must specify which to use.
127127
If a list of indices is provided, peaks and locations are concatenated across the segments.
@@ -149,7 +149,7 @@ def __init__(
149149
direction: str = "y",
150150
recording: BaseRecording | None = None,
151151
sampling_frequency: float | None = None,
152-
segment_index: int | list | None = None,
152+
segment_indices: list[int] | None = None,
153153
depth_lim: tuple[float, float] | None = None,
154154
color_amplitude: bool = True,
155155
scatter_decimate: int | None = None,
@@ -197,16 +197,13 @@ def __init__(
197197

198198
unique_segments = np.unique(peaks["segment_index"])
199199

200-
if segment_index is None:
200+
if segment_indices is None:
201201
if len(unique_segments) == 1:
202202
segment_indices = [int(unique_segments[0])]
203203
else:
204204
raise ValueError("segment_index must be specified if there are multiple segments")
205-
elif isinstance(segment_index, int):
206-
segment_indices = [segment_index]
207-
elif isinstance(segment_index, list):
208-
segment_indices = segment_index
209-
else:
205+
206+
if not isinstance(segment_indices, list):
210207
raise ValueError("segment_index must be an int or a list of ints")
211208

212209
# Validate all segment indices exist in the data

src/spikeinterface/widgets/rasters.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from warnings import warn
55

66
from .base import BaseWidget, to_attr, default_backend_kwargs
7-
from .utils import get_some_colors
7+
from .utils import get_some_colors, validate_segment_indices
88

99

1010
class BaseRasterWidget(BaseWidget):
@@ -23,7 +23,7 @@ class BaseRasterWidget(BaseWidget):
2323
converted to a dict of dicts with segment 0.
2424
unit_ids : array-like | None, default: None
2525
List of unit_ids to plot
26-
segment_index : int | list | None, default: None
26+
segment_indices : list | None, default: None
2727
For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments.
2828
For single-segment data, this parameter is ignored.
2929
total_duration : int | None, default: None
@@ -65,7 +65,7 @@ def __init__(
6565
spike_train_data: dict,
6666
y_axis_data: dict,
6767
unit_ids: list | None = None,
68-
segment_index: int | list | None = None,
68+
segment_indices: list | None = None,
6969
total_duration: int | None = None,
7070
plot_histograms: bool = False,
7171
bins: int | None = None,
@@ -93,22 +93,17 @@ def __init__(
9393
available_segments.sort() # Ensure consistent ordering
9494

9595
# Determine which segments to use
96-
if segment_index is None:
96+
if segment_indices is None:
9797
# Use all segments by default
9898
segments_to_use = available_segments
99-
elif isinstance(segment_index, int):
100-
# Single segment specified
101-
if segment_index not in available_segments:
102-
raise ValueError(f"segment_index {segment_index} not found in data")
103-
segments_to_use = [segment_index]
104-
elif isinstance(segment_index, list):
99+
elif isinstance(segment_indices, list):
105100
# Multiple segments specified
106-
for idx in segment_index:
101+
for idx in segment_indices:
107102
if idx not in available_segments:
108-
raise ValueError(f"segment_index {idx} not found in data")
109-
segments_to_use = segment_index
103+
raise ValueError(f"segment_index {idx} not found in avialable segments {available_segments}")
104+
segments_to_use = segment_indices
110105
else:
111-
raise ValueError("segment_index must be int, list, or None")
106+
raise ValueError("segment_index must be `list` or `None`")
112107

113108
# Get all unit IDs present in any segment if not specified
114109
if unit_ids is None:
@@ -391,7 +386,7 @@ class RasterWidget(BaseRasterWidget):
391386
A sorting object
392387
sorting_analyzer : SortingAnalyzer | None, default: None
393388
A sorting analyzer object
394-
segment_index : int or list of int or None, default: None
389+
segment_indices : list of int or None, default: None
395390
The segment index or indices to use. If None and there are multiple segments, defaults to 0.
396391
If a list of indices is provided, spike trains are concatenated across the specified segments.
397392
unit_ids : list
@@ -406,7 +401,7 @@ def __init__(
406401
self,
407402
sorting=None,
408403
sorting_analyzer=None,
409-
segment_index=None,
404+
segment_indices=None,
410405
unit_ids=None,
411406
time_range=None,
412407
color="k",
@@ -424,30 +419,7 @@ def __init__(
424419

425420
sorting = self.ensure_sorting(sorting)
426421

427-
num_segments = sorting.get_num_segments()
428-
429-
# Handle segment_index input
430-
if num_segments > 1:
431-
if segment_index is None:
432-
warn("More than one segment available! Using `segment_index = 0`.")
433-
segment_index = 0
434-
else:
435-
segment_index = 0
436-
437-
# Convert segment_index to list for consistent processing
438-
if isinstance(segment_index, int):
439-
segment_indices = [segment_index]
440-
elif isinstance(segment_index, list):
441-
segment_indices = segment_index
442-
else:
443-
raise ValueError("segment_index must be an int or a list of ints")
444-
445-
# Validate segment indices
446-
for idx in segment_indices:
447-
if not isinstance(idx, int):
448-
raise ValueError(f"Each segment index must be an integer, got {type(idx)}")
449-
if idx < 0 or idx >= num_segments:
450-
raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})")
422+
segment_indices = validate_segment_indices(sorting, segment_indices)
451423

452424
if unit_ids is None:
453425
unit_ids = sorting.unit_ids

src/spikeinterface/widgets/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from warnings import warn
44
import numpy as np
55

6+
from spikeinterface.core import BaseSorting
67

78
def get_some_colors(
89
keys,
@@ -349,3 +350,46 @@ def make_units_table_from_analyzer(
349350
)
350351

351352
return units_table
353+
354+
def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting):
355+
"""
356+
Validate a list of segment indices for a sorting object.
357+
358+
Parameters
359+
----------
360+
segment_indices : list of int
361+
The segment index or indices to validate.
362+
sorting : BaseSorting
363+
The sorting object to validate against.
364+
365+
Returns
366+
-------
367+
list of int
368+
A list of valid segment indices.
369+
370+
Raises
371+
------
372+
ValueError
373+
If the segment indices are not valid.
374+
"""
375+
num_segments = sorting.get_num_segments()
376+
377+
# Handle segment_indices input
378+
if segment_indices is None:
379+
if num_segments > 1:
380+
warn("Segment indices not specified. Using first available segment only.")
381+
return [0]
382+
383+
# Convert segment_index to list for consistent processing
384+
if not isinstance(segment_indices, list):
385+
raise ValueError("segment_indices must be a list of ints - available segments are: " + list(range(num_segments)))
386+
387+
# Validate segment indices
388+
for idx in segment_indices:
389+
if not isinstance(idx, int):
390+
raise ValueError(f"Each segment index must be an integer, got {type(idx)}")
391+
if idx < 0 or idx >= num_segments:
392+
raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})")
393+
394+
395+
return segment_indices

0 commit comments

Comments
 (0)