Skip to content

Commit afdea78

Browse files
committed
Update durations to use a list
1 parent 540db00 commit afdea78

3 files changed

Lines changed: 51 additions & 79 deletions

File tree

src/spikeinterface/widgets/amplitudes.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,18 @@ def __init__(
125125
if plot_histograms and bins is None:
126126
bins = 100
127127

128-
# Calculate total duration across all segments for x-axis limits
129-
total_duration = 0
128+
# Calculate durations for all segments for x-axis limits
129+
durations = []
130130
for idx in segment_indices:
131131
duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency
132-
total_duration += duration
132+
durations.append(duration)
133133

134134
# Build the plot data with the full dict of dicts structure
135135
plot_data = dict(
136136
unit_colors=unit_colors,
137137
plot_histograms=plot_histograms,
138138
bins=bins,
139-
total_duration=total_duration,
139+
durations=durations,
140140
unit_ids=unit_ids,
141141
hide_unit_selector=hide_unit_selector,
142142
plot_legend=plot_legend,
@@ -150,8 +150,6 @@ def __init__(
150150
first_segment = segment_indices[0]
151151
plot_data["spike_train_data"] = spiketrains_by_segment[first_segment]
152152
plot_data["y_axis_data"] = amplitudes_by_segment[first_segment]
153-
print(plot_data["spike_train_data"])
154-
print(plot_data["y_axis_data"])
155153
else:
156154
# Otherwise use the full dict of dicts structure with all segments
157155
plot_data["spike_train_data"] = spiketrains_by_segment
@@ -178,7 +176,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
178176
]
179177

180178
self.view = vv.SpikeAmplitudes(
181-
start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector
179+
start_time_sec=0, end_time_sec=np.sum(dp.durations), plots=sa_items, hide_unit_selector=dp.hide_unit_selector
182180
)
183181

184182
self.url = handle_display_and_url(self, self.view, **backend_kwargs)

src/spikeinterface/widgets/motion.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ def __init__(
256256
else:
257257
color_kwargs = dict(color=color, c=None, alpha=alpha)
258258

259-
# Calculate total duration for x-axis limits
260-
total_duration = 0
259+
# Calculate segment durations for x-axis limits
260+
durations = []
261261
for seg_idx in segment_indices:
262262
if recording is not None and hasattr(recording, "get_duration"):
263263
duration = recording.get_duration(seg_idx)
@@ -270,7 +270,7 @@ def __init__(
270270
duration = (max_sample + 1) / sampling_frequency
271271
else:
272272
duration = 0
273-
total_duration += duration
273+
durations.append(duration)
274274

275275
plot_data = dict(
276276
spike_train_data=spike_train_data,
@@ -281,7 +281,7 @@ def __init__(
281281
scatter_decimate=scatter_decimate,
282282
title="Peak depth",
283283
y_label="Depth [um]",
284-
total_duration=total_duration,
284+
durations=durations,
285285
)
286286

287287
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)
@@ -414,10 +414,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
414414
dp.recording,
415415
)
416416

417-
commpon_drift_map_kwargs = dict(
417+
common_drift_map_kwargs = dict(
418418
direction=dp.motion.direction,
419419
recording=dp.recording,
420-
segment_indices=list(dp.segment_index),
420+
segment_indices=[dp.segment_index],
421421
depth_lim=dp.depth_lim,
422422
scatter_decimate=dp.scatter_decimate,
423423
color_amplitude=dp.color_amplitude,
@@ -434,15 +434,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
434434
dp.peak_locations,
435435
ax=ax0,
436436
immediate_plot=True,
437-
**commpon_drift_map_kwargs,
437+
**common_drift_map_kwargs,
438438
)
439439

440440
_ = DriftRasterMapWidget(
441441
dp.peaks,
442442
corrected_location,
443443
ax=ax1,
444444
immediate_plot=True,
445-
**commpon_drift_map_kwargs,
445+
**common_drift_map_kwargs,
446446
)
447447

448448
ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black")

src/spikeinterface/widgets/rasters.py

Lines changed: 38 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class BaseRasterWidget(BaseWidget):
2626
segment_indices : list | None, default: None
2727
For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments.
2828
For single-segment data, this parameter is ignored.
29-
total_duration : int | None, default: None
30-
Duration of spike_train_data in seconds.
29+
durations : list | None, default: None
30+
List of durations per segment of spike_train_data in seconds.
3131
plot_histograms : bool, default: False
3232
Plot histogram of y-axis data in another subplot
3333
bins : int | None, default: None
@@ -66,7 +66,7 @@ def __init__(
6666
y_axis_data: dict,
6767
unit_ids: list | None = None,
6868
segment_indices: list | None = None,
69-
total_duration: int | None = None,
69+
durations: list | None = None,
7070
plot_histograms: bool = False,
7171
bins: int | None = None,
7272
scatter_decimate: int = 1,
@@ -112,67 +112,41 @@ def __init__(
112112
all_units.update(spike_train_data[seg_idx].keys())
113113
unit_ids = list(all_units)
114114

115-
# Calculate segment durations and boundaries
116-
segment_durations = []
117-
for seg_idx in segments_to_use:
118-
max_time = 0
119-
for unit_id in unit_ids:
120-
if unit_id in spike_train_data[seg_idx]:
121-
unit_times = spike_train_data[seg_idx][unit_id]
122-
if len(unit_times) > 0:
123-
max_time = max(max_time, np.max(unit_times))
124-
segment_durations.append(max_time)
125-
126115
# Calculate cumulative durations for segment boundaries
127-
cumulative_durations = [0]
128-
for duration in segment_durations[:-1]:
129-
cumulative_durations.append(cumulative_durations[-1] + duration)
130-
131-
# Segment boundaries for visualization (only internal boundaries)
132-
segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None
116+
segment_boundaries = np.cumsum(durations)
117+
cumulative_durations = np.concatenate([[0], segment_boundaries])
133118

134119
# Concatenate data across segments with proper time offsets
135-
concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids}
136-
concatenated_y_axis = {unit_id: [] for unit_id in unit_ids}
137-
138-
for i, seg_idx in enumerate(segments_to_use):
139-
offset = cumulative_durations[i]
140-
141-
for unit_id in unit_ids:
142-
if unit_id in spike_train_data[seg_idx]:
143-
# Get spike times for this unit in this segment
144-
spike_times = spike_train_data[seg_idx][unit_id]
145-
146-
# Adjust spike times by adding cumulative duration of previous segments
147-
if offset > 0:
148-
adjusted_times = spike_times + offset
149-
else:
150-
adjusted_times = spike_times
151-
152-
# Get y-axis data for this unit in this segment
153-
y_values = y_axis_data[seg_idx][unit_id]
154-
155-
# Concatenate with any existing data
156-
if len(concatenated_spike_trains[unit_id]) > 0:
157-
concatenated_spike_trains[unit_id] = np.concatenate(
158-
[concatenated_spike_trains[unit_id], adjusted_times]
159-
)
160-
concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values])
161-
else:
162-
concatenated_spike_trains[unit_id] = adjusted_times
163-
concatenated_y_axis[unit_id] = y_values
164-
165-
# Update spike train and y-axis data with concatenated values
166-
processed_spike_train_data = concatenated_spike_trains
167-
processed_y_axis_data = concatenated_y_axis
168-
169-
# Calculate total duration from the data if not provided
170-
if total_duration is None:
171-
total_duration = cumulative_durations[-1] + segment_durations[-1]
120+
concatenated_spike_trains = {unit_id: np.array([]) for unit_id in unit_ids}
121+
concatenated_y_axis = {unit_id: np.array([]) for unit_id in unit_ids}
122+
123+
for offset, spike_train_segment, y_axis_segment in zip(
124+
cumulative_durations,
125+
[spike_train_data[idx] for idx in segments_to_use],
126+
[y_axis_data[idx] for idx in segments_to_use]
127+
):
128+
# Process each unit in the current segment
129+
for unit_id, spike_times in spike_train_segment.items():
130+
if unit_id not in unit_ids:
131+
continue
132+
133+
# Get y-axis values for this unit
134+
y_values = y_axis_segment[unit_id]
135+
136+
# Apply offset to spike times
137+
adjusted_times = spike_times + offset
138+
139+
# Add to concatenated data
140+
concatenated_spike_trains[unit_id] = np.concatenate(
141+
[concatenated_spike_trains[unit_id], adjusted_times]
142+
)
143+
concatenated_y_axis[unit_id] = np.concatenate(
144+
[concatenated_y_axis[unit_id], y_values]
145+
)
172146

173147
plot_data = dict(
174-
spike_train_data=processed_spike_train_data,
175-
y_axis_data=processed_y_axis_data,
148+
spike_train_data=concatenated_spike_trains,
149+
y_axis_data=concatenated_y_axis,
176150
unit_ids=unit_ids,
177151
plot_histograms=plot_histograms,
178152
y_lim=y_lim,
@@ -182,7 +156,7 @@ def __init__(
182156
unit_colors=unit_colors,
183157
y_label=y_label,
184158
title=title,
185-
total_duration=total_duration,
159+
durations=durations,
186160
plot_legend=plot_legend,
187161
bins=bins,
188162
y_ticks=y_ticks,
@@ -275,7 +249,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
275249
scatter_ax.set_ylim(*dp.y_lim)
276250
x_lim = dp.x_lim
277251
if x_lim is None:
278-
x_lim = [0, dp.total_duration]
252+
x_lim = [0, np.sum(dp.durations)]
279253
scatter_ax.set_xlim(x_lim)
280254

281255
if dp.y_ticks:
@@ -432,7 +406,7 @@ def __init__(
432406
unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)}
433407

434408
# Calculate total duration across all segments
435-
total_duration = 0
409+
durations = []
436410
for seg_idx in segment_indices:
437411
# Try to get duration from recording if available
438412
if recording is not None:
@@ -446,7 +420,7 @@ def __init__(
446420
max_time = max(max_time, np.max(st))
447421
duration = max_time
448422

449-
total_duration += duration
423+
durations.append(duration)
450424

451425
# Initialize dicts for this segment
452426
spike_train_data[seg_idx] = {}
@@ -486,7 +460,7 @@ def __init__(
486460
unit_colors=unit_colors,
487461
plot_histograms=None,
488462
y_ticks=y_ticks,
489-
total_duration=total_duration,
463+
durations=durations,
490464
)
491465

492466
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

0 commit comments

Comments
 (0)