@@ -26,8 +26,8 @@ class BaseRasterWidget(BaseWidget):
2626 segment_indices : list | None, default: None
2727 For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments.
2828 For single-segment data, this parameter is ignored.
29- total_duration : int | None, default: None
30- Duration of spike_train_data in seconds.
29+ durations : list | None, default: None
30+ List of durations per segment of spike_train_data in seconds.
3131 plot_histograms : bool, default: False
3232 Plot histogram of y-axis data in another subplot
3333 bins : int | None, default: None
@@ -66,7 +66,7 @@ def __init__(
6666 y_axis_data : dict ,
6767 unit_ids : list | None = None ,
6868 segment_indices : list | None = None ,
69- total_duration : int | None = None ,
69+ durations : list | None = None ,
7070 plot_histograms : bool = False ,
7171 bins : int | None = None ,
7272 scatter_decimate : int = 1 ,
@@ -112,67 +112,41 @@ def __init__(
112112 all_units .update (spike_train_data [seg_idx ].keys ())
113113 unit_ids = list (all_units )
114114
115- # Calculate segment durations and boundaries
116- segment_durations = []
117- for seg_idx in segments_to_use :
118- max_time = 0
119- for unit_id in unit_ids :
120- if unit_id in spike_train_data [seg_idx ]:
121- unit_times = spike_train_data [seg_idx ][unit_id ]
122- if len (unit_times ) > 0 :
123- max_time = max (max_time , np .max (unit_times ))
124- segment_durations .append (max_time )
125-
126115 # Calculate cumulative durations for segment boundaries
127- cumulative_durations = [0 ]
128- for duration in segment_durations [:- 1 ]:
129- cumulative_durations .append (cumulative_durations [- 1 ] + duration )
130-
131- # Segment boundaries for visualization (only internal boundaries)
132- segment_boundaries = cumulative_durations [1 :] if len (segments_to_use ) > 1 else None
116+ segment_boundaries = np .cumsum (durations )
117+ cumulative_durations = np .concatenate ([[0 ], segment_boundaries ])
133118
134119 # Concatenate data across segments with proper time offsets
135- concatenated_spike_trains = {unit_id : [] for unit_id in unit_ids }
136- concatenated_y_axis = {unit_id : [] for unit_id in unit_ids }
137-
138- for i , seg_idx in enumerate (segments_to_use ):
139- offset = cumulative_durations [i ]
140-
141- for unit_id in unit_ids :
142- if unit_id in spike_train_data [seg_idx ]:
143- # Get spike times for this unit in this segment
144- spike_times = spike_train_data [seg_idx ][unit_id ]
145-
146- # Adjust spike times by adding cumulative duration of previous segments
147- if offset > 0 :
148- adjusted_times = spike_times + offset
149- else :
150- adjusted_times = spike_times
151-
152- # Get y-axis data for this unit in this segment
153- y_values = y_axis_data [seg_idx ][unit_id ]
154-
155- # Concatenate with any existing data
156- if len (concatenated_spike_trains [unit_id ]) > 0 :
157- concatenated_spike_trains [unit_id ] = np .concatenate (
158- [concatenated_spike_trains [unit_id ], adjusted_times ]
159- )
160- concatenated_y_axis [unit_id ] = np .concatenate ([concatenated_y_axis [unit_id ], y_values ])
161- else :
162- concatenated_spike_trains [unit_id ] = adjusted_times
163- concatenated_y_axis [unit_id ] = y_values
164-
165- # Update spike train and y-axis data with concatenated values
166- processed_spike_train_data = concatenated_spike_trains
167- processed_y_axis_data = concatenated_y_axis
168-
169- # Calculate total duration from the data if not provided
170- if total_duration is None :
171- total_duration = cumulative_durations [- 1 ] + segment_durations [- 1 ]
120+ concatenated_spike_trains = {unit_id : np .array ([]) for unit_id in unit_ids }
121+ concatenated_y_axis = {unit_id : np .array ([]) for unit_id in unit_ids }
122+
123+ for offset , spike_train_segment , y_axis_segment in zip (
124+ cumulative_durations ,
125+ [spike_train_data [idx ] for idx in segments_to_use ],
126+ [y_axis_data [idx ] for idx in segments_to_use ]
127+ ):
128+ # Process each unit in the current segment
129+ for unit_id , spike_times in spike_train_segment .items ():
130+ if unit_id not in unit_ids :
131+ continue
132+
133+ # Get y-axis values for this unit
134+ y_values = y_axis_segment [unit_id ]
135+
136+ # Apply offset to spike times
137+ adjusted_times = spike_times + offset
138+
139+ # Add to concatenated data
140+ concatenated_spike_trains [unit_id ] = np .concatenate (
141+ [concatenated_spike_trains [unit_id ], adjusted_times ]
142+ )
143+ concatenated_y_axis [unit_id ] = np .concatenate (
144+ [concatenated_y_axis [unit_id ], y_values ]
145+ )
172146
173147 plot_data = dict (
174- spike_train_data = processed_spike_train_data ,
175- y_axis_data = processed_y_axis_data ,
148+ spike_train_data = concatenated_spike_trains ,
149+ y_axis_data = concatenated_y_axis ,
176150 unit_ids = unit_ids ,
177151 plot_histograms = plot_histograms ,
178152 y_lim = y_lim ,
@@ -182,7 +156,7 @@ def __init__(
182156 unit_colors = unit_colors ,
183157 y_label = y_label ,
184158 title = title ,
185- total_duration = total_duration ,
159+ durations = durations ,
186160 plot_legend = plot_legend ,
187161 bins = bins ,
188162 y_ticks = y_ticks ,
@@ -275,7 +249,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
275249 scatter_ax .set_ylim (* dp .y_lim )
276250 x_lim = dp .x_lim
277251 if x_lim is None :
278- x_lim = [0 , dp .total_duration ]
252+ x_lim = [0 , np . sum ( dp .durations ) ]
279253 scatter_ax .set_xlim (x_lim )
280254
281255 if dp .y_ticks :
@@ -432,7 +406,7 @@ def __init__(
432406 unit_indices_map = {unit_id : i for i , unit_id in enumerate (unit_ids )}
433407
434408 # Calculate total duration across all segments
435- total_duration = 0
409+ durations = []
436410 for seg_idx in segment_indices :
437411 # Try to get duration from recording if available
438412 if recording is not None :
@@ -446,7 +420,7 @@ def __init__(
446420 max_time = max (max_time , np .max (st ))
447421 duration = max_time
448422
449- total_duration += duration
423+ durations . append ( duration )
450424
451425 # Initialize dicts for this segment
452426 spike_train_data [seg_idx ] = {}
@@ -486,7 +460,7 @@ def __init__(
486460 unit_colors = unit_colors ,
487461 plot_histograms = None ,
488462 y_ticks = y_ticks ,
489- total_duration = total_duration ,
463+ durations = durations ,
490464 )
491465
492466 BaseRasterWidget .__init__ (self , ** plot_data , backend = backend , ** backend_kwargs )
0 commit comments