Skip to content

Commit 6f75a23

Browse files
authored
Merge pull request #4035 from jakeswann1/add-segment-plotting
Add multi-segment capability to BaseRasterWidget and children
2 parents b2396f2 + 1a5c6c9 commit 6f75a23

5 files changed

Lines changed: 459 additions & 99 deletions

File tree

src/spikeinterface/widgets/amplitudes.py

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .rasters import BaseRasterWidget
77
from .base import BaseWidget, to_attr
8-
from .utils import get_some_colors
8+
from .utils import get_some_colors, validate_segment_indices, get_segment_durations
99

1010
from spikeinterface.core.sortinganalyzer import SortingAnalyzer
1111

@@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget):
2525
unit_colors : dict | None, default: None
2626
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
2727
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
28-
segment_index : int or None, default: None
29-
The segment index (or None if mono-segment)
28+
segment_indices : list of int or None, default: None
29+
Segment index or indices to plot. If None and there are multiple segments, defaults to 0.
30+
If list, spike trains and amplitudes are concatenated across the specified segments.
3031
max_spikes_per_unit : int or None, default: None
3132
Number of max spikes per unit to display. Use None for all spikes
3233
y_lim : tuple or None, default: None
@@ -51,70 +52,104 @@ def __init__(
5152
sorting_analyzer: SortingAnalyzer,
5253
unit_ids=None,
5354
unit_colors=None,
54-
segment_index=None,
55+
segment_indices=None,
5556
max_spikes_per_unit=None,
5657
y_lim=None,
5758
scatter_decimate=1,
5859
hide_unit_selector=False,
5960
plot_histograms=False,
6061
bins=None,
6162
plot_legend=True,
63+
segment_index=None,
6264
backend=None,
6365
**backend_kwargs,
6466
):
67+
import warnings
68+
69+
# Handle deprecation of segment_index parameter
70+
if segment_index is not None:
71+
warnings.warn(
72+
"The 'segment_index' parameter is deprecated and will be removed in a future version. "
73+
"Use 'segment_indices' instead.",
74+
DeprecationWarning,
75+
stacklevel=2,
76+
)
77+
if segment_indices is None:
78+
if isinstance(segment_index, int):
79+
segment_indices = [segment_index]
80+
else:
81+
segment_indices = segment_index
6582

6683
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)
67-
6884
sorting = sorting_analyzer.sorting
6985
self.check_extensions(sorting_analyzer, "spike_amplitudes")
7086

87+
# Get amplitudes by segment
7188
amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit")
7289

7390
if unit_ids is None:
7491
unit_ids = sorting.unit_ids
7592

76-
if sorting.get_num_segments() > 1:
77-
if segment_index is None:
78-
warn("More than one segment available! Using `segment_index = 0`.")
79-
segment_index = 0
80-
else:
81-
segment_index = 0
93+
# Handle segment_index input
94+
segment_indices = validate_segment_indices(segment_indices, sorting)
95+
96+
# Check for SortingView backend
97+
is_sortingview = backend == "sortingview"
98+
99+
# For SortingView, ensure we're only using a single segment
100+
if is_sortingview and len(segment_indices) > 1:
101+
warn("SortingView backend currently supports only single segment. Using first segment.")
102+
segment_indices = [segment_indices[0]]
82103

83-
amplitudes_segment = amplitudes[segment_index]
84-
total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency
104+
# Create multi-segment data structure (dict of dicts)
105+
spiketrains_by_segment = {}
106+
amplitudes_by_segment = {}
85107

86-
all_spiketrains = {
87-
unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True)
88-
for unit_id in sorting.unit_ids
89-
}
108+
for idx in segment_indices:
109+
amplitudes_segment = amplitudes[idx]
90110

91-
all_amplitudes = amplitudes_segment
111+
# Initialize for this segment
112+
spiketrains_by_segment[idx] = {}
113+
amplitudes_by_segment[idx] = {}
114+
115+
for unit_id in unit_ids:
116+
# Get spike times for this unit in this segment
117+
spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True)
118+
amps = amplitudes_segment[unit_id]
119+
120+
# Store data in dict of dicts format
121+
spiketrains_by_segment[idx][unit_id] = spike_times
122+
amplitudes_by_segment[idx][unit_id] = amps
123+
124+
# Apply max_spikes_per_unit limit if specified
92125
if max_spikes_per_unit is not None:
93-
spiketrains_to_plot = dict()
94-
amplitudes_to_plot = dict()
95-
for unit, st in all_spiketrains.items():
96-
amps = all_amplitudes[unit]
97-
if len(st) > max_spikes_per_unit:
98-
random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False)
99-
spiketrains_to_plot[unit] = st[random_idxs]
100-
amplitudes_to_plot[unit] = amps[random_idxs]
101-
else:
102-
spiketrains_to_plot[unit] = st
103-
amplitudes_to_plot[unit] = amps
104-
else:
105-
spiketrains_to_plot = all_spiketrains
106-
amplitudes_to_plot = all_amplitudes
126+
for idx in segment_indices:
127+
for unit_id in unit_ids:
128+
st = spiketrains_by_segment[idx][unit_id]
129+
amps = amplitudes_by_segment[idx][unit_id]
130+
if len(st) > max_spikes_per_unit:
131+
# Scale down the number of spikes proportionally per segment
132+
# to ensure we have max_spikes_per_unit total after concatenation
133+
segment_count = len(segment_indices)
134+
segment_max = max(1, max_spikes_per_unit // segment_count)
135+
136+
if len(st) > segment_max:
137+
random_idxs = np.random.choice(len(st), size=segment_max, replace=False)
138+
spiketrains_by_segment[idx][unit_id] = st[random_idxs]
139+
amplitudes_by_segment[idx][unit_id] = amps[random_idxs]
107140

108141
if plot_histograms and bins is None:
109142
bins = 100
110143

144+
# Calculate durations for all segments for x-axis limits
145+
durations = get_segment_durations(sorting, segment_indices)
146+
147+
# Build the plot data with the full dict of dicts structure
111148
plot_data = dict(
112-
spike_train_data=spiketrains_to_plot,
113-
y_axis_data=amplitudes_to_plot,
114149
unit_colors=unit_colors,
115150
plot_histograms=plot_histograms,
116151
bins=bins,
117-
total_duration=total_duration,
152+
durations=durations,
118153
unit_ids=unit_ids,
119154
hide_unit_selector=hide_unit_selector,
120155
plot_legend=plot_legend,
@@ -123,6 +158,17 @@ def __init__(
123158
scatter_decimate=scatter_decimate,
124159
)
125160

161+
# If using SortingView, extract just the first segment's data as flat dicts
162+
if is_sortingview:
163+
first_segment = segment_indices[0]
164+
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
165+
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
166+
else:
167+
# Otherwise use the full dict of dicts structure with all segments
168+
plot_data["spike_train_data"] = spiketrains_by_segment
169+
plot_data["y_axis_data"] = amplitudes_by_segment
170+
plot_data["segment_indices"] = segment_indices
171+
126172
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
127173

128174
def plot_sortingview(self, data_plot, **backend_kwargs):
@@ -143,7 +189,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
143189
]
144190

145191
self.view = vv.SpikeAmplitudes(
146-
start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector
192+
start_time_sec=0,
193+
end_time_sec=np.sum(dp.durations),
194+
plots=sa_items,
195+
hide_unit_selector=dp.hide_unit_selector,
147196
)
148197

149198
self.url = handle_display_and_url(self, self.view, **backend_kwargs)

src/spikeinterface/widgets/motion.py

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ class DriftRasterMapWidget(BaseRasterWidget):
117117
"spike_locations" extension computed.
118118
direction : "x" or "y", default: "y"
119119
The direction to display. "y" is the depth direction.
120-
segment_index : int, default: None
121-
The segment index to display.
122120
recording : RecordingExtractor | None, default: None
123121
The recording extractor object (only used to get "real" times).
124-
segment_index : int, default: 0
125-
The segment index to display.
126122
sampling_frequency : float, default: None
127123
The sampling frequency (needed if recording is None).
124+
segment_indices : list of int or None, default: None
125+
The segment index or indices to display. If None and there's only one segment, it's used.
126+
If None and there are multiple segments, you must specify which to use.
127+
If a list of indices is provided, peaks and locations are concatenated across the segments.
128128
depth_lim : tuple or None, default: None
129129
The min and max depth to display, if None (min and max of the recording).
130130
scatter_decimate : int, default: None
@@ -149,25 +149,46 @@ def __init__(
149149
direction: str = "y",
150150
recording: BaseRecording | None = None,
151151
sampling_frequency: float | None = None,
152-
segment_index: int | None = None,
152+
segment_indices: list[int] | None = None,
153153
depth_lim: tuple[float, float] | None = None,
154154
color_amplitude: bool = True,
155155
scatter_decimate: int | None = None,
156156
cmap: str = "inferno",
157157
color: str = "Gray",
158158
clim: tuple[float, float] | None = None,
159159
alpha: float = 1,
160+
segment_index: int | list[int] | None = None,
160161
backend: str | None = None,
161162
**backend_kwargs,
162163
):
164+
import warnings
165+
from matplotlib.pyplot import colormaps
166+
from matplotlib.colors import Normalize
167+
168+
# Handle deprecation of segment_index parameter
169+
if segment_index is not None:
170+
warnings.warn(
171+
"The 'segment_index' parameter is deprecated and will be removed in a future version. "
172+
"Use 'segment_indices' instead.",
173+
DeprecationWarning,
174+
stacklevel=2,
175+
)
176+
if segment_indices is None:
177+
if isinstance(segment_index, int):
178+
segment_indices = [segment_index]
179+
else:
180+
segment_indices = segment_index
181+
163182
assert peaks is not None or sorting_analyzer is not None
183+
164184
if peaks is not None:
165185
assert peak_locations is not None
166186
if recording is None:
167187
assert sampling_frequency is not None, "If recording is None, you must provide the sampling frequency"
168188
else:
169189
sampling_frequency = recording.sampling_frequency
170190
peak_amplitudes = peaks["amplitude"]
191+
171192
if sorting_analyzer is not None:
172193
if sorting_analyzer.has_recording():
173194
recording = sorting_analyzer.recording
@@ -190,29 +211,56 @@ def __init__(
190211
else:
191212
peak_amplitudes = None
192213

193-
if segment_index is None:
194-
assert (
195-
len(np.unique(peaks["segment_index"])) == 1
196-
), "segment_index must be specified if there are multiple segments"
197-
segment_index = 0
198-
else:
199-
peak_mask = peaks["segment_index"] == segment_index
200-
peaks = peaks[peak_mask]
201-
peak_locations = peak_locations[peak_mask]
202-
if peak_amplitudes is not None:
203-
peak_amplitudes = peak_amplitudes[peak_mask]
204-
205-
from matplotlib.pyplot import colormaps
214+
unique_segments = np.unique(peaks["segment_index"])
206215

207-
if color_amplitude:
208-
amps = peak_amplitudes
216+
if segment_indices is None:
217+
if len(unique_segments) == 1:
218+
segment_indices = [int(unique_segments[0])]
219+
else:
220+
raise ValueError("segment_indices must be specified if there are multiple segments")
221+
222+
if not isinstance(segment_indices, list):
223+
raise ValueError("segment_indices must be a list of ints")
224+
225+
# Validate all segment indices exist in the data
226+
for idx in segment_indices:
227+
if idx not in unique_segments:
228+
raise ValueError(f"segment_index {idx} not found in peaks data")
229+
230+
# Filter data for the selected segments
231+
# Note: For simplicity, we'll filter all data first, then construct dict of dicts
232+
segment_mask = np.isin(peaks["segment_index"], segment_indices)
233+
filtered_peaks = peaks[segment_mask]
234+
filtered_locations = peak_locations[segment_mask]
235+
if peak_amplitudes is not None:
236+
filtered_amplitudes = peak_amplitudes[segment_mask]
237+
238+
# Create dict of dicts structure for the base class
239+
spike_train_data = {}
240+
y_axis_data = {}
241+
242+
# Process each segment separately
243+
for seg_idx in segment_indices:
244+
segment_mask = filtered_peaks["segment_index"] == seg_idx
245+
segment_peaks = filtered_peaks[segment_mask]
246+
segment_locations = filtered_locations[segment_mask]
247+
248+
# Convert peak times to seconds
249+
spike_times = segment_peaks["sample_index"] / sampling_frequency
250+
251+
# Store in dict of dicts format (using 0 as the "unit" id)
252+
spike_train_data[seg_idx] = {0: spike_times}
253+
y_axis_data[seg_idx] = {0: segment_locations[direction]}
254+
255+
if color_amplitude and peak_amplitudes is not None:
256+
amps = filtered_amplitudes
209257
amps_abs = np.abs(amps)
210258
q_95 = np.quantile(amps_abs, 0.95)
211-
cmap = colormaps[cmap]
259+
cmap_obj = colormaps[cmap]
212260
if clim is None:
213261
amps = amps_abs
214262
amps /= q_95
215-
c = cmap(amps)
263+
c = cmap_obj(amps)
216264
else:
217265
from matplotlib.colors import Normalize
218266

@@ -226,18 +274,30 @@ def __init__(
226274
else:
227275
color_kwargs = dict(color=color, c=None, alpha=alpha)
228276

229-
# convert data into format that `BaseRasterWidget` can take it in
230-
spike_train_data = {0: peaks["sample_index"] / sampling_frequency}
231-
y_axis_data = {0: peak_locations[direction]}
277+
# Calculate segment durations for x-axis limits
278+
if recording is not None:
279+
durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices]
280+
else:
281+
# Find boundaries between segments using searchsorted
282+
segment_boundaries = [
283+
np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
284+
]
285+
286+
# Calculate durations from max sample in each segment
287+
durations = [
288+
(filtered_peaks["sample_index"][end - 1] + 1) / sampling_frequency for (_, end) in segment_boundaries
289+
]
232290

233291
plot_data = dict(
234292
spike_train_data=spike_train_data,
235293
y_axis_data=y_axis_data,
294+
segment_indices=segment_indices,
236295
y_lim=depth_lim,
237296
color_kwargs=color_kwargs,
238297
scatter_decimate=scatter_decimate,
239298
title="Peak depth",
240299
y_label="Depth [um]",
300+
durations=durations,
241301
)
242302

243303
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
@@ -370,10 +430,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
370430
dp.recording,
371431
)
372432

373-
commpon_drift_map_kwargs = dict(
433+
common_drift_map_kwargs = dict(
374434
direction=dp.motion.direction,
375435
recording=dp.recording,
376-
segment_index=dp.segment_index,
436+
segment_indices=[dp.segment_index],
377437
depth_lim=dp.depth_lim,
378438
scatter_decimate=dp.scatter_decimate,
379439
color_amplitude=dp.color_amplitude,
@@ -390,15 +450,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
390450
dp.peak_locations,
391451
ax=ax0,
392452
immediate_plot=True,
393-
**commpon_drift_map_kwargs,
453+
**common_drift_map_kwargs,
394454
)
395455

396456
_ = DriftRasterMapWidget(
397457
dp.peaks,
398458
corrected_location,
399459
ax=ax1,
400460
immediate_plot=True,
401-
**commpon_drift_map_kwargs,
461+
**common_drift_map_kwargs,
402462
)
403463

404464
ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")

0 commit comments

Comments
 (0)