From 4e655933b0f97325598f49c61c593ce6dad4df85 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Tue, 25 Mar 2025 15:21:15 +0000 Subject: [PATCH 01/17] Add multi-segment support for amplitudes widget --- src/spikeinterface/widgets/amplitudes.py | 84 +++++++++++++++++++----- 1 file changed, 67 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 197fefbab2..fb3bfbd4db 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -73,34 +73,84 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - if sorting.get_num_segments() > 1: + num_segments = sorting.get_num_segments() + + # Handle segment_index input + if num_segments > 1: if segment_index is None: warn("More than one segment available! Using `segment_index = 0`.") segment_index = 0 else: segment_index = 0 + + # Convert segment_index to list for consistent processing + if isinstance(segment_index, int): + segment_indices = [segment_index] + elif isinstance(segment_index, list): + segment_indices = segment_index + else: + raise ValueError("segment_index must be an int or a list of ints") + + # Validate segment indices + for idx in segment_indices: + if not isinstance(idx, int): + raise ValueError(f"Each segment index must be an integer, got {type(idx)}") + if idx < 0 or idx >= num_segments: + raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") + + # Initialize dictionaries for concatenated data + all_spiketrains = {unit_id: [] for unit_id in unit_ids} + all_amplitudes = {unit_id: [] for unit_id in unit_ids} + + # Calculate cumulative durations for spike time adjustments + cumulative_durations = [0] + for i in range(len(segment_indices) - 1): + segment_idx = segment_indices[i] + duration = sorting_analyzer.get_num_samples(segment_idx) / sorting_analyzer.sampling_frequency + cumulative_durations.append(cumulative_durations[-1] + duration) + + # Calculate total duration across all segments + total_duration = cumulative_durations[-1] + if segment_indices: # Check if there are any segments + total_duration += sorting_analyzer.get_num_samples(segment_indices[-1]) / sorting_analyzer.sampling_frequency + + # Concatenate spike trains and amplitudes across segments + for i, segment_idx in enumerate(segment_indices): + amplitudes_segment = amplitudes[segment_idx] + offset = cumulative_durations[i] + + for unit_id in unit_ids: + # Get spike times for this unit in this segment + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_idx, return_times=True) + + # Adjust spike times by adding cumulative duration of previous segments + if offset > 0: + spike_times = spike_times + offset + + # Get amplitudes for this unit in this segment + amps = amplitudes_segment[unit_id] + + # Concatenate with any existing data + if len(all_spiketrains[unit_id]) > 0: + all_spiketrains[unit_id] = np.concatenate([all_spiketrains[unit_id], spike_times]) + all_amplitudes[unit_id] = np.concatenate([all_amplitudes[unit_id], amps]) + else: + all_spiketrains[unit_id] = spike_times + all_amplitudes[unit_id] = amps - amplitudes_segment = amplitudes[segment_index] - total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency - - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in sorting.unit_ids - } - - all_amplitudes = amplitudes_segment if max_spikes_per_unit is not None: spiketrains_to_plot = dict() amplitudes_to_plot = dict() - for unit, st in all_spiketrains.items(): - amps = all_amplitudes[unit] + for unit_id in unit_ids: + st = all_spiketrains[unit_id] + amps = all_amplitudes[unit_id] if len(st) > max_spikes_per_unit: random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False) - spiketrains_to_plot[unit] = st[random_idxs] - amplitudes_to_plot[unit] = amps[random_idxs] + spiketrains_to_plot[unit_id] = st[random_idxs] + amplitudes_to_plot[unit_id] = amps[random_idxs] else: - spiketrains_to_plot[unit] = st - amplitudes_to_plot[unit] = amps + spiketrains_to_plot[unit_id] = st + amplitudes_to_plot[unit_id] = amps else: spiketrains_to_plot = all_spiketrains amplitudes_to_plot = all_amplitudes @@ -124,7 +174,7 @@ def __init__( ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) - + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url From 2539a8948345743297bae037d44ebc22303ff8a7 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Tue, 25 Mar 2025 17:28:55 +0000 Subject: [PATCH 02/17] Update base raster widget and children to handle multi-segment --- src/spikeinterface/widgets/amplitudes.py | 96 +++++------ src/spikeinterface/widgets/motion.py | 105 +++++++++--- src/spikeinterface/widgets/rasters.py | 209 ++++++++++++++++++++--- 3 files changed, 303 insertions(+), 107 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index fb3bfbd4db..3b99f4dd50 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget): unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. - segment_index : int or None, default: None - The segment index (or None if mono-segment) + segment_index : int or list of int or None, default: None + Segment index or indices to plot. If None and there are multiple segments, defaults to 0. + If list, spike trains and amplitudes are concatenated across the specified segments. max_spikes_per_unit : int or None, default: None Number of max spikes per unit to display. Use None for all spikes y_lim : tuple or None, default: None @@ -64,10 +65,10 @@ def __init__( ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting = sorting_analyzer.sorting self.check_extensions(sorting_analyzer, "spike_amplitudes") + # Get amplitudes by segment amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: @@ -98,70 +99,57 @@ def __init__( if idx < 0 or idx >= num_segments: raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") - # Initialize dictionaries for concatenated data - all_spiketrains = {unit_id: [] for unit_id in unit_ids} - all_amplitudes = {unit_id: [] for unit_id in unit_ids} - - # Calculate cumulative durations for spike time adjustments - cumulative_durations = [0] - for i in range(len(segment_indices) - 1): - segment_idx = segment_indices[i] - duration = sorting_analyzer.get_num_samples(segment_idx) / sorting_analyzer.sampling_frequency - cumulative_durations.append(cumulative_durations[-1] + duration) + # Create multi-segment data structure (dict of dicts) + spiketrains_by_segment = {} + amplitudes_by_segment = {} - # Calculate total duration across all segments - total_duration = cumulative_durations[-1] - if segment_indices: # Check if there are any segments - total_duration += sorting_analyzer.get_num_samples(segment_indices[-1]) / sorting_analyzer.sampling_frequency - - # Concatenate spike trains and amplitudes across segments - for i, segment_idx in enumerate(segment_indices): - amplitudes_segment = amplitudes[segment_idx] - offset = cumulative_durations[i] + for idx in segment_indices: + amplitudes_segment = amplitudes[idx] + + # Initialize for this segment + spiketrains_by_segment[idx] = {} + amplitudes_by_segment[idx] = {} for unit_id in unit_ids: # Get spike times for this unit in this segment - spike_times = sorting.get_unit_spike_train(unit_id, segment_index=segment_idx, return_times=True) - - # Adjust spike times by adding cumulative duration of previous segments - if offset > 0: - spike_times = spike_times + offset - - # Get amplitudes for this unit in this segment + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True) amps = amplitudes_segment[unit_id] - # Concatenate with any existing data - if len(all_spiketrains[unit_id]) > 0: - all_spiketrains[unit_id] = np.concatenate([all_spiketrains[unit_id], spike_times]) - all_amplitudes[unit_id] = np.concatenate([all_amplitudes[unit_id], amps]) - else: - all_spiketrains[unit_id] = spike_times - all_amplitudes[unit_id] = amps - + # Store data in dict of dicts format + spiketrains_by_segment[idx][unit_id] = spike_times + amplitudes_by_segment[idx][unit_id] = amps + + # Apply max_spikes_per_unit limit if specified if max_spikes_per_unit is not None: - spiketrains_to_plot = dict() - amplitudes_to_plot = dict() - for unit_id in unit_ids: - st = all_spiketrains[unit_id] - amps = all_amplitudes[unit_id] - if len(st) > max_spikes_per_unit: - random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False) - spiketrains_to_plot[unit_id] = st[random_idxs] - amplitudes_to_plot[unit_id] = amps[random_idxs] - else: - spiketrains_to_plot[unit_id] = st - amplitudes_to_plot[unit_id] = amps - else: - spiketrains_to_plot = all_spiketrains - amplitudes_to_plot = all_amplitudes + for idx in segment_indices: + for unit_id in unit_ids: + st = spiketrains_by_segment[idx][unit_id] + amps = amplitudes_by_segment[idx][unit_id] + if len(st) > max_spikes_per_unit: + # Scale down the number of spikes proportionally per segment + # to ensure we have max_spikes_per_unit total after concatenation + segment_count = len(segment_indices) + segment_max = max(1, max_spikes_per_unit // segment_count) + + if len(st) > segment_max: + random_idxs = np.random.choice(len(st), size=segment_max, replace=False) + spiketrains_by_segment[idx][unit_id] = st[random_idxs] + amplitudes_by_segment[idx][unit_id] = amps[random_idxs] if plot_histograms and bins is None: bins = 100 + # Calculate total duration across all segments for x-axis limits + total_duration = 0 + for idx in segment_indices: + duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency + total_duration += duration + plot_data = dict( - spike_train_data=spiketrains_to_plot, - y_axis_data=amplitudes_to_plot, + spike_train_data=spiketrains_by_segment, + y_axis_data=amplitudes_by_segment, unit_colors=unit_colors, + segment_index=segment_indices, plot_histograms=plot_histograms, bins=bins, total_duration=total_duration, diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index a0c7e1e28c..564373b704 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -117,14 +117,14 @@ class DriftRasterMapWidget(BaseRasterWidget): "spike_locations" extension computed. direction : "x" or "y", default: "y" The direction to display. "y" is the depth direction. - segment_index : int, default: None - The segment index to display. recording : RecordingExtractor | None, default: None The recording extractor object (only used to get "real" times). - segment_index : int, default: 0 - The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None). + segment_index : int or list of int or None, default: None + The segment index or indices to display. If None and there's only one segment, it's used. + If None and there are multiple segments, you must specify which to use. + If a list of indices is provided, peaks and locations are concatenated across the segments. depth_lim : tuple or None, default: None The min and max depth to display, if None (min and max of the recording). scatter_decimate : int, default: None @@ -149,7 +149,7 @@ def __init__( direction: str = "y", recording: BaseRecording | None = None, sampling_frequency: float | None = None, - segment_index: int | None = None, + segment_index: int | list | None = None, depth_lim: tuple[float, float] | None = None, color_amplitude: bool = True, scatter_decimate: int | None = None, @@ -160,7 +160,11 @@ def __init__( backend: str | None = None, **backend_kwargs, ): + from matplotlib.pyplot import colormaps + from matplotlib.colors import Normalize + assert peaks is not None or sorting_analyzer is not None + if peaks is not None: assert peak_locations is not None if recording is None: @@ -168,6 +172,7 @@ def __init__( else: sampling_frequency = recording.sampling_frequency peak_amplitudes = peaks["amplitude"] + if sorting_analyzer is not None: if sorting_analyzer.has_recording(): recording = sorting_analyzer.recording @@ -190,32 +195,62 @@ def __init__( else: peak_amplitudes = None + unique_segments = np.unique(peaks["segment_index"]) + if segment_index is None: - assert ( - len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there are multiple segments" - segment_index = 0 + if len(unique_segments) == 1: + segment_indices = [int(unique_segments[0])] + else: + raise ValueError("segment_index must be specified if there are multiple segments") + elif isinstance(segment_index, int): + segment_indices = [segment_index] + elif isinstance(segment_index, list): + segment_indices = segment_index else: - peak_mask = peaks["segment_index"] == segment_index - peaks = peaks[peak_mask] - peak_locations = peak_locations[peak_mask] - if peak_amplitudes is not None: - peak_amplitudes = peak_amplitudes[peak_mask] - - from matplotlib.pyplot import colormaps - - if color_amplitude: - amps = peak_amplitudes + raise ValueError("segment_index must be an int or a list of ints") + + # Validate all segment indices exist in the data + for idx in segment_indices: + if idx not in unique_segments: + raise ValueError(f"segment_index {idx} not found in peaks data") + + # Filter data for the selected segments + # Note: For simplicity, we'll filter all data first, then construct dict of dicts + segment_mask = np.isin(peaks["segment_index"], segment_indices) + filtered_peaks = peaks[segment_mask] + filtered_locations = peak_locations[segment_mask] + if peak_amplitudes is not None: + filtered_amplitudes = peak_amplitudes[segment_mask] + + # Create dict of dicts structure for the base class + spike_train_data = {} + y_axis_data = {} + + # Process each segment separately + for seg_idx in segment_indices: + segment_mask = filtered_peaks["segment_index"] == seg_idx + segment_peaks = filtered_peaks[segment_mask] + segment_locations = filtered_locations[segment_mask] + + # Convert peak times to seconds + spike_times = segment_peaks["sample_index"] / sampling_frequency + + # Store in dict of dicts format (using 0 as the "unit" id) + spike_train_data[seg_idx] = {0: spike_times} + y_axis_data[seg_idx] = {0: segment_locations[direction]} + + if color_amplitude and peak_amplitudes is not None: + amps = filtered_amplitudes amps_abs = np.abs(amps) q_95 = np.quantile(amps_abs, 0.95) - cmap = colormaps[cmap] + cmap_obj = colormaps[cmap] if clim is None: amps = amps_abs amps /= q_95 - c = cmap(amps) + c = cmap_obj(amps) else: - norm_function = Normalize(vmin=dp.clim[0], vmax=dp.clim[1], clip=True) - c = cmap(norm_function(amps)) + norm_function = Normalize(vmin=clim[0], vmax=clim[1], clip=True) + c = cmap_obj(norm_function(amps)) color_kwargs = dict( color=None, c=c, @@ -223,19 +258,33 @@ def __init__( ) else: color_kwargs = dict(color=color, c=None, alpha=alpha) - - # convert data into format that `BaseRasterWidget` can take it in - spike_train_data = {0: peaks["sample_index"] / sampling_frequency} - y_axis_data = {0: peak_locations[direction]} - + + # Calculate total duration for x-axis limits + total_duration = 0 + for seg_idx in segment_indices: + if recording is not None and hasattr(recording, "get_duration"): + duration = recording.get_duration(seg_idx) + else: + # Estimate from spike times + segment_mask = filtered_peaks["segment_index"] == seg_idx + segment_peaks = filtered_peaks[segment_mask] + if len(segment_peaks) > 0: + max_sample = np.max(segment_peaks["sample_index"]) + duration = (max_sample + 1) / sampling_frequency + else: + duration = 0 + total_duration += duration + plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, + segment_index=segment_indices, y_lim=depth_lim, color_kwargs=color_kwargs, scatter_decimate=scatter_decimate, title="Peak depth", y_label="Depth [um]", + total_duration=total_duration, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 398ae4d728..42b5876da4 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -15,12 +15,17 @@ class BaseRasterWidget(BaseWidget): Parameters ---------- - spike_train_data : dict - A dict of spike trains, indexed by the unit_id - y_axis_data : dict - A dict of the y-axis data, indexed by the unit_id + spike_train_data : dict of dicts + A dict of dicts where the structure is spike_train_data[segment_index][unit_id]. + y_axis_data : dict of dicts + A dict of dicts where the structure is y_axis_data[segment_index][unit_id]. + For backwards compatibility, a flat dict indexed by unit_id will be internally + converted to a dict of dicts with segment 0. unit_ids : array-like | None, default: None List of unit_ids to plot + segment_index : int | list | None, default: None + For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. + For single-segment data, this parameter is ignored. total_duration : int | None, default: None Duration of spike_train_data in seconds. plot_histograms : bool, default: False @@ -48,6 +53,8 @@ class BaseRasterWidget(BaseWidget): Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used. hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed + segment_boundary_kwargs : dict | None, default: None + Additional arguments for the segment boundary lines, passed to `matplotlib.axvline` backend : str | None, default None Which plotting backend to use e.g. 'matplotlib', 'ipywidgets'. If None, uses default from `get_default_plotter_backend`. @@ -58,6 +65,7 @@ def __init__( spike_train_data: dict, y_axis_data: dict, unit_ids: list | None = None, + segment_index: int | list | None = None, total_duration: int | None = None, plot_histograms: bool = False, bins: int | None = None, @@ -71,13 +79,103 @@ def __init__( y_label: str | None = None, y_ticks: bool = False, hide_unit_selector: bool = True, + segment_boundary_kwargs: dict | None = None, backend: str | None = None, **backend_kwargs, ): + # Set default segment boundary kwargs if not provided + if segment_boundary_kwargs is None: + segment_boundary_kwargs = {"color": "gray", "linestyle": "--", "alpha": 0.7} + + # Process the data + available_segments = list(spike_train_data.keys()) + available_segments.sort() # Ensure consistent ordering + + # Determine which segments to use + if segment_index is None: + # Use all segments by default + segments_to_use = available_segments + elif isinstance(segment_index, int): + # Single segment specified + if segment_index not in available_segments: + raise ValueError(f"segment_index {segment_index} not found in data") + segments_to_use = [segment_index] + elif isinstance(segment_index, list): + # Multiple segments specified + for idx in segment_index: + if idx not in available_segments: + raise ValueError(f"segment_index {idx} not found in data") + segments_to_use = segment_index + else: + raise ValueError("segment_index must be int, list, or None") + + # Get all unit IDs present in any segment if not specified + if unit_ids is None: + all_units = set() + for seg_idx in segments_to_use: + all_units.update(spike_train_data[seg_idx].keys()) + unit_ids = list(all_units) + + # Calculate segment durations and boundaries + segment_durations = [] + for seg_idx in segments_to_use: + max_time = 0 + for unit_id in unit_ids: + if unit_id in spike_train_data[seg_idx]: + unit_times = spike_train_data[seg_idx][unit_id] + if len(unit_times) > 0: + max_time = max(max_time, np.max(unit_times)) + segment_durations.append(max_time) + + # Calculate cumulative durations for segment boundaries + cumulative_durations = [0] + for duration in segment_durations[:-1]: + cumulative_durations.append(cumulative_durations[-1] + duration) + + # Segment boundaries for visualization (only internal boundaries) + segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None + + # Concatenate data across segments with proper time offsets + concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids} + concatenated_y_axis = {unit_id: [] for unit_id in unit_ids} + + for i, seg_idx in enumerate(segments_to_use): + offset = cumulative_durations[i] + + for unit_id in unit_ids: + if unit_id in spike_train_data[seg_idx]: + # Get spike times for this unit in this segment + spike_times = spike_train_data[seg_idx][unit_id] + + # Adjust spike times by adding cumulative duration of previous segments + if offset > 0: + adjusted_times = spike_times + offset + else: + adjusted_times = spike_times + + # Get y-axis data for this unit in this segment + y_values = y_axis_data[seg_idx][unit_id] + + # Concatenate with any existing data + if len(concatenated_spike_trains[unit_id]) > 0: + concatenated_spike_trains[unit_id] = np.concatenate([concatenated_spike_trains[unit_id], adjusted_times]) + concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values]) + else: + concatenated_spike_trains[unit_id] = adjusted_times + concatenated_y_axis[unit_id] = y_values + + # Update spike train and y-axis data with concatenated values + processed_spike_train_data = concatenated_spike_trains + processed_y_axis_data = concatenated_y_axis + + # Calculate total duration from the data if not provided + if total_duration is None: + total_duration = cumulative_durations[-1] + segment_durations[-1] + plot_data = dict( - spike_train_data=spike_train_data, - y_axis_data=y_axis_data, + spike_train_data=processed_spike_train_data, + y_axis_data=processed_y_axis_data, unit_ids=unit_ids, plot_histograms=plot_histograms, y_lim=y_lim, @@ -92,6 +190,8 @@ def __init__( bins=bins, y_ticks=y_ticks, hide_unit_selector=hide_unit_selector, + segment_boundaries=segment_boundaries, + segment_boundary_kwargs=segment_boundary_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -134,7 +234,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): y_axis_data = dp.y_axis_data for unit_id in unit_ids: - + if unit_id not in spike_train_data: + continue # Skip this unit if not in data + unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] @@ -155,6 +257,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): count, bins = np.histogram(unit_y_data, bins=bins) ax_hist.plot(count, bins[:-1], color=unit_colors[unit_id], alpha=0.8) + # Add segment boundary lines if provided + if getattr(dp, 'segment_boundaries', None) is not None: + for boundary in dp.segment_boundaries: + scatter_ax.axvline(boundary, **dp.segment_boundary_kwargs) + if dp.plot_histograms: ax_hist = self.axes.flatten()[1] ax_hist.set_ylim(scatter_ax.get_ylim()) @@ -282,8 +389,9 @@ class RasterWidget(BaseRasterWidget): A sorting object sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object - segment_index : None or int - The segment index. + segment_index : int or list of int or None, default: None + The segment index or indices to use. If None and there are multiple segments, defaults to 0. + If a list of indices is provided, spike trains are concatenated across the specified segments. unit_ids : list List of unit ids time_range : list @@ -303,39 +411,88 @@ def __init__( backend=None, **backend_kwargs, ): + recording = None if sorting is None and sorting_analyzer is None: raise Exception("Must supply either a sorting or a sorting_analyzer") elif sorting is not None and sorting_analyzer is not None: raise Exception("Should supply either a sorting or a sorting_analyzer, not both") elif sorting_analyzer is not None: sorting = sorting_analyzer.sorting + recording = sorting_analyzer.recording sorting = self.ensure_sorting(sorting) - if sorting.get_num_segments() > 1: + num_segments = sorting.get_num_segments() + + # Handle segment_index input + if num_segments > 1: if segment_index is None: warn("More than one segment available! Using `segment_index = 0`.") segment_index = 0 else: segment_index = 0 + + # Convert segment_index to list for consistent processing + if isinstance(segment_index, int): + segment_indices = [segment_index] + elif isinstance(segment_index, list): + segment_indices = segment_index + else: + raise ValueError("segment_index must be an int or a list of ints") + + # Validate segment indices + for idx in segment_indices: + if not isinstance(idx, int): + raise ValueError(f"Each segment index must be an integer, got {type(idx)}") + if idx < 0 or idx >= num_segments: + raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") if unit_ids is None: unit_ids = sorting.unit_ids - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in unit_ids - } - + # Create dict of dicts structure + spike_train_data = {} + y_axis_data = {} + + # Create a lookup dictionary for unit indices + unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} + + # Calculate total duration across all segments + total_duration = 0 + for seg_idx in segment_indices: + # Try to get duration from recording if available + if recording is not None: + duration = recording.get_duration(seg_idx) + else: + # Fallback: estimate from max spike time + max_time = 0 + for unit_id in unit_ids: + st = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) + if len(st) > 0: + max_time = max(max_time, np.max(st)) + duration = max_time + + total_duration += duration + + # Initialize dicts for this segment + spike_train_data[seg_idx] = {} + y_axis_data[seg_idx] = {} + + # Get spike trains for each unit in this segment + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) + + # Store spike trains + spike_train_data[seg_idx][unit_id] = spike_times + + # Create raster locations (y-values for plotting) + unit_index = unit_indices_map[unit_id] + y_axis_data[seg_idx][unit_id] = unit_index * np.ones(len(spike_times)) + + # Apply time range filtering if specified if time_range is not None: assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - for unit_id in unit_ids: - unit_st = all_spiketrains[unit_id] - all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])] - - raster_locations = { - unit_id: unit_index * np.ones(len(all_spiketrains[unit_id])) for unit_index, unit_id in enumerate(unit_ids) - } + # Let BaseRasterWidget handle the filtering unit_indices = list(range(len(unit_ids))) @@ -346,14 +503,16 @@ def __init__( y_ticks = {"ticks": unit_indices, "labels": unit_ids} plot_data = dict( - spike_train_data=all_spiketrains, - y_axis_data=raster_locations, + spike_train_data=spike_train_data, + y_axis_data=y_axis_data, + segment_index=segment_indices, x_lim=time_range, y_label="Unit id", unit_ids=unit_ids, unit_colors=unit_colors, plot_histograms=None, y_ticks=y_ticks, + total_duration=total_duration, ) - BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) \ No newline at end of file From 11a1845c61b456f5071c203913ec1db3cc3cd1c0 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Tue, 25 Mar 2025 17:37:19 +0000 Subject: [PATCH 03/17] Retain sortingview compatibility --- src/spikeinterface/widgets/amplitudes.py | 27 +++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 3b99f4dd50..063641e40f 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -84,6 +84,14 @@ def __init__( else: segment_index = 0 + # Check for SortingView backend + is_sortingview = backend == "sortingview" + + # For SortingView, ensure we're only using a single segment + if is_sortingview and isinstance(segment_index, list) and len(segment_index) > 1: + warn("SortingView backend currently supports only single segment. Using first segment.") + segment_index = segment_index[0] + # Convert segment_index to list for consistent processing if isinstance(segment_index, int): segment_indices = [segment_index] @@ -144,12 +152,10 @@ def __init__( for idx in segment_indices: duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency total_duration += duration - + + # Build the plot data with the full dict of dicts structure plot_data = dict( - spike_train_data=spiketrains_by_segment, - y_axis_data=amplitudes_by_segment, unit_colors=unit_colors, - segment_index=segment_indices, plot_histograms=plot_histograms, bins=bins, total_duration=total_duration, @@ -160,7 +166,18 @@ def __init__( y_lim=y_lim, scatter_decimate=scatter_decimate, ) - + + # If using SortingView, extract just the first segment's data as flat dicts + if is_sortingview: + first_segment = segment_indices[0] + plot_data["spike_train_data"] = spiketrains_by_segment[first_segment] + plot_data["y_axis_data"] = amplitudes_by_segment[first_segment] + else: + # Otherwise use the full dict of dicts structure with all segments + plot_data["spike_train_data"] = spiketrains_by_segment + plot_data["y_axis_data"] = amplitudes_by_segment + plot_data["segment_index"] = segment_indices + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): From 16e6272f39359eba4df9556082f46ecb04ce8283 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Mar 2025 17:39:53 +0000 Subject: [PATCH 04/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 30 ++++++------ src/spikeinterface/widgets/motion.py | 26 +++++----- src/spikeinterface/widgets/rasters.py | 60 ++++++++++++------------ 3 files changed, 59 insertions(+), 57 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 063641e40f..ef85f9ca30 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -75,7 +75,7 @@ def __init__( unit_ids = sorting.unit_ids num_segments = sorting.get_num_segments() - + # Handle segment_index input if num_segments > 1: if segment_index is None: @@ -83,15 +83,15 @@ def __init__( segment_index = 0 else: segment_index = 0 - + # Check for SortingView backend is_sortingview = backend == "sortingview" - + # For SortingView, ensure we're only using a single segment if is_sortingview and isinstance(segment_index, list) and len(segment_index) > 1: warn("SortingView backend currently supports only single segment. Using first segment.") segment_index = segment_index[0] - + # Convert segment_index to list for consistent processing if isinstance(segment_index, int): segment_indices = [segment_index] @@ -99,7 +99,7 @@ def __init__( segment_indices = segment_index else: raise ValueError("segment_index must be an int or a list of ints") - + # Validate segment indices for idx in segment_indices: if not isinstance(idx, int): @@ -110,23 +110,23 @@ def __init__( # Create multi-segment data structure (dict of dicts) spiketrains_by_segment = {} amplitudes_by_segment = {} - + for idx in segment_indices: amplitudes_segment = amplitudes[idx] - + # Initialize for this segment spiketrains_by_segment[idx] = {} amplitudes_by_segment[idx] = {} - + for unit_id in unit_ids: # Get spike times for this unit in this segment spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True) amps = amplitudes_segment[unit_id] - + # Store data in dict of dicts format spiketrains_by_segment[idx][unit_id] = spike_times amplitudes_by_segment[idx][unit_id] = amps - + # Apply max_spikes_per_unit limit if specified if max_spikes_per_unit is not None: for idx in segment_indices: @@ -138,7 +138,7 @@ def __init__( # to ensure we have max_spikes_per_unit total after concatenation segment_count = len(segment_indices) segment_max = max(1, max_spikes_per_unit // segment_count) - + if len(st) > segment_max: random_idxs = np.random.choice(len(st), size=segment_max, replace=False) spiketrains_by_segment[idx][unit_id] = st[random_idxs] @@ -152,7 +152,7 @@ def __init__( for idx in segment_indices: duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency total_duration += duration - + # Build the plot data with the full dict of dicts structure plot_data = dict( unit_colors=unit_colors, @@ -166,7 +166,7 @@ def __init__( y_lim=y_lim, scatter_decimate=scatter_decimate, ) - + # If using SortingView, extract just the first segment's data as flat dicts if is_sortingview: first_segment = segment_indices[0] @@ -177,9 +177,9 @@ def __init__( plot_data["spike_train_data"] = spiketrains_by_segment plot_data["y_axis_data"] = amplitudes_by_segment plot_data["segment_index"] = segment_indices - + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) - + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 564373b704..1a1512545b 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -162,9 +162,9 @@ def __init__( ): from matplotlib.pyplot import colormaps from matplotlib.colors import Normalize - + assert peaks is not None or sorting_analyzer is not None - + if peaks is not None: assert peak_locations is not None if recording is None: @@ -172,7 +172,7 @@ def __init__( else: sampling_frequency = recording.sampling_frequency peak_amplitudes = peaks["amplitude"] - + if sorting_analyzer is not None: if sorting_analyzer.has_recording(): recording = sorting_analyzer.recording @@ -196,7 +196,7 @@ def __init__( peak_amplitudes = None unique_segments = np.unique(peaks["segment_index"]) - + if segment_index is None: if len(unique_segments) == 1: segment_indices = [int(unique_segments[0])] @@ -208,12 +208,12 @@ def __init__( segment_indices = segment_index else: raise ValueError("segment_index must be an int or a list of ints") - + # Validate all segment indices exist in the data for idx in segment_indices: if idx not in unique_segments: raise ValueError(f"segment_index {idx} not found in peaks data") - + # Filter data for the selected segments # Note: For simplicity, we'll filter all data first, then construct dict of dicts segment_mask = np.isin(peaks["segment_index"], segment_indices) @@ -221,24 +221,24 @@ def __init__( filtered_locations = peak_locations[segment_mask] if peak_amplitudes is not None: filtered_amplitudes = peak_amplitudes[segment_mask] - + # Create dict of dicts structure for the base class spike_train_data = {} y_axis_data = {} - + # Process each segment separately for seg_idx in segment_indices: segment_mask = filtered_peaks["segment_index"] == seg_idx segment_peaks = filtered_peaks[segment_mask] segment_locations = filtered_locations[segment_mask] - + # Convert peak times to seconds spike_times = segment_peaks["sample_index"] / sampling_frequency - + # Store in dict of dicts format (using 0 as the "unit" id) spike_train_data[seg_idx] = {0: spike_times} y_axis_data[seg_idx] = {0: segment_locations[direction]} - + if color_amplitude and peak_amplitudes is not None: amps = filtered_amplitudes amps_abs = np.abs(amps) @@ -258,7 +258,7 @@ def __init__( ) else: color_kwargs = dict(color=color, c=None, alpha=alpha) - + # Calculate total duration for x-axis limits total_duration = 0 for seg_idx in segment_indices: @@ -274,7 +274,7 @@ def __init__( else: duration = 0 total_duration += duration - + plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 42b5876da4..3d4470c249 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -87,11 +87,11 @@ def __init__( # Set default segment boundary kwargs if not provided if segment_boundary_kwargs is None: segment_boundary_kwargs = {"color": "gray", "linestyle": "--", "alpha": 0.7} - + # Process the data available_segments = list(spike_train_data.keys()) available_segments.sort() # Ensure consistent ordering - + # Determine which segments to use if segment_index is None: # Use all segments by default @@ -109,14 +109,14 @@ def __init__( segments_to_use = segment_index else: raise ValueError("segment_index must be int, list, or None") - + # Get all unit IDs present in any segment if not specified if unit_ids is None: all_units = set() for seg_idx in segments_to_use: all_units.update(spike_train_data[seg_idx].keys()) unit_ids = list(all_units) - + # Calculate segment durations and boundaries segment_durations = [] for seg_idx in segments_to_use: @@ -127,52 +127,54 @@ def __init__( if len(unit_times) > 0: max_time = max(max_time, np.max(unit_times)) segment_durations.append(max_time) - + # Calculate cumulative durations for segment boundaries cumulative_durations = [0] for duration in segment_durations[:-1]: cumulative_durations.append(cumulative_durations[-1] + duration) - + # Segment boundaries for visualization (only internal boundaries) segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None - + # Concatenate data across segments with proper time offsets concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids} concatenated_y_axis = {unit_id: [] for unit_id in unit_ids} - + for i, seg_idx in enumerate(segments_to_use): offset = cumulative_durations[i] - + for unit_id in unit_ids: if unit_id in spike_train_data[seg_idx]: # Get spike times for this unit in this segment spike_times = spike_train_data[seg_idx][unit_id] - + # Adjust spike times by adding cumulative duration of previous segments if offset > 0: adjusted_times = spike_times + offset else: adjusted_times = spike_times - + # Get y-axis data for this unit in this segment y_values = y_axis_data[seg_idx][unit_id] - + # Concatenate with any existing data if len(concatenated_spike_trains[unit_id]) > 0: - concatenated_spike_trains[unit_id] = np.concatenate([concatenated_spike_trains[unit_id], adjusted_times]) + concatenated_spike_trains[unit_id] = np.concatenate( + [concatenated_spike_trains[unit_id], adjusted_times] + ) concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values]) else: concatenated_spike_trains[unit_id] = adjusted_times concatenated_y_axis[unit_id] = y_values - + # Update spike train and y-axis data with concatenated values processed_spike_train_data = concatenated_spike_trains processed_y_axis_data = concatenated_y_axis - + # Calculate total duration from the data if not provided if total_duration is None: total_duration = cumulative_durations[-1] + segment_durations[-1] - + plot_data = dict( spike_train_data=processed_spike_train_data, y_axis_data=processed_y_axis_data, @@ -236,7 +238,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for unit_id in unit_ids: if unit_id not in spike_train_data: continue # Skip this unit if not in data - + unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate] unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate] @@ -258,7 +260,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax_hist.plot(count, bins[:-1], color=unit_colors[unit_id], alpha=0.8) # Add segment boundary lines if provided - if getattr(dp, 'segment_boundaries', None) is not None: + if getattr(dp, "segment_boundaries", None) is not None: for boundary in dp.segment_boundaries: scatter_ax.axvline(boundary, **dp.segment_boundary_kwargs) @@ -423,7 +425,7 @@ def __init__( sorting = self.ensure_sorting(sorting) num_segments = sorting.get_num_segments() - + # Handle segment_index input if num_segments > 1: if segment_index is None: @@ -431,7 +433,7 @@ def __init__( segment_index = 0 else: segment_index = 0 - + # Convert segment_index to list for consistent processing if isinstance(segment_index, int): segment_indices = [segment_index] @@ -439,7 +441,7 @@ def __init__( segment_indices = segment_index else: raise ValueError("segment_index must be an int or a list of ints") - + # Validate segment indices for idx in segment_indices: if not isinstance(idx, int): @@ -453,10 +455,10 @@ def __init__( # Create dict of dicts structure spike_train_data = {} y_axis_data = {} - + # Create a lookup dictionary for unit indices unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} - + # Calculate total duration across all segments total_duration = 0 for seg_idx in segment_indices: @@ -471,20 +473,20 @@ def __init__( if len(st) > 0: max_time = max(max_time, np.max(st)) duration = max_time - + total_duration += duration - + # Initialize dicts for this segment spike_train_data[seg_idx] = {} y_axis_data[seg_idx] = {} - + # Get spike trains for each unit in this segment for unit_id in unit_ids: spike_times = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) - + # Store spike trains spike_train_data[seg_idx][unit_id] = spike_times - + # Create raster locations (y-values for plotting) unit_index = unit_indices_map[unit_id] y_axis_data[seg_idx][unit_id] = unit_index * np.ones(len(spike_times)) @@ -515,4 +517,4 @@ def __init__( total_duration=total_duration, ) - BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) \ No newline at end of file + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) From cbc790ce0f46a6680f73d32ae161db6b43f54f15 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Tue, 29 Apr 2025 13:24:22 +0100 Subject: [PATCH 05/17] Improve segment validation to list only. Add unitls function for validation --- src/spikeinterface/widgets/amplitudes.py | 34 +++------------- src/spikeinterface/widgets/motion.py | 13 +++--- src/spikeinterface/widgets/rasters.py | 52 ++++++------------------ src/spikeinterface/widgets/utils.py | 44 ++++++++++++++++++++ 4 files changed, 67 insertions(+), 76 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index ef85f9ca30..a6fb9948ae 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -5,7 +5,7 @@ from .rasters import BaseRasterWidget from .base import BaseWidget, to_attr -from .utils import get_some_colors +from .utils import get_some_colors, validate_segment_indices from spikeinterface.core.sortinganalyzer import SortingAnalyzer @@ -25,7 +25,7 @@ class AmplitudesWidget(BaseRasterWidget): unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. - segment_index : int or list of int or None, default: None + segment_indices : list of int or None, default: None Segment index or indices to plot. If None and there are multiple segments, defaults to 0. If list, spike trains and amplitudes are concatenated across the specified segments. max_spikes_per_unit : int or None, default: None @@ -52,7 +52,7 @@ def __init__( sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, - segment_index=None, + segment_indices=None, max_spikes_per_unit=None, y_lim=None, scatter_decimate=1, @@ -74,38 +74,16 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - num_segments = sorting.get_num_segments() - # Handle segment_index input - if num_segments > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 + segment_indices = validate_segment_indices(segment_indices, sorting) # Check for SortingView backend is_sortingview = backend == "sortingview" # For SortingView, ensure we're only using a single segment - if is_sortingview and isinstance(segment_index, list) and len(segment_index) > 1: + if is_sortingview and len(segment_indices) > 1: warn("SortingView backend currently supports only single segment. Using first segment.") - segment_index = segment_index[0] - - # Convert segment_index to list for consistent processing - if isinstance(segment_index, int): - segment_indices = [segment_index] - elif isinstance(segment_index, list): - segment_indices = segment_index - else: - raise ValueError("segment_index must be an int or a list of ints") - - # Validate segment indices - for idx in segment_indices: - if not isinstance(idx, int): - raise ValueError(f"Each segment index must be an integer, got {type(idx)}") - if idx < 0 or idx >= num_segments: - raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") + segment_indices = segment_indices[0] # Create multi-segment data structure (dict of dicts) spiketrains_by_segment = {} diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 1a1512545b..f932ed44f6 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -121,7 +121,7 @@ class DriftRasterMapWidget(BaseRasterWidget): The recording extractor object (only used to get "real" times). sampling_frequency : float, default: None The sampling frequency (needed if recording is None). - segment_index : int or list of int or None, default: None + segment_indices : list of int or None, default: None The segment index or indices to display. If None and there's only one segment, it's used. If None and there are multiple segments, you must specify which to use. If a list of indices is provided, peaks and locations are concatenated across the segments. @@ -149,7 +149,7 @@ def __init__( direction: str = "y", recording: BaseRecording | None = None, sampling_frequency: float | None = None, - segment_index: int | list | None = None, + segment_indices: list[int] | None = None, depth_lim: tuple[float, float] | None = None, color_amplitude: bool = True, scatter_decimate: int | None = None, @@ -197,16 +197,13 @@ def __init__( unique_segments = np.unique(peaks["segment_index"]) - if segment_index is None: + if segment_indices is None: if len(unique_segments) == 1: segment_indices = [int(unique_segments[0])] else: raise ValueError("segment_index must be specified if there are multiple segments") - elif isinstance(segment_index, int): - segment_indices = [segment_index] - elif isinstance(segment_index, list): - segment_indices = segment_index - else: + + if not isinstance(segment_indices, list): raise ValueError("segment_index must be an int or a list of ints") # Validate all segment indices exist in the data diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 3d4470c249..f03f980f24 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,7 +4,7 @@ from warnings import warn from .base import BaseWidget, to_attr, default_backend_kwargs -from .utils import get_some_colors +from .utils import get_some_colors, validate_segment_indices class BaseRasterWidget(BaseWidget): @@ -23,7 +23,7 @@ class BaseRasterWidget(BaseWidget): converted to a dict of dicts with segment 0. unit_ids : array-like | None, default: None List of unit_ids to plot - segment_index : int | list | None, default: None + segment_indices : list | None, default: None For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. For single-segment data, this parameter is ignored. total_duration : int | None, default: None @@ -65,7 +65,7 @@ def __init__( spike_train_data: dict, y_axis_data: dict, unit_ids: list | None = None, - segment_index: int | list | None = None, + segment_indices: list | None = None, total_duration: int | None = None, plot_histograms: bool = False, bins: int | None = None, @@ -93,22 +93,17 @@ def __init__( available_segments.sort() # Ensure consistent ordering # Determine which segments to use - if segment_index is None: + if segment_indices is None: # Use all segments by default segments_to_use = available_segments - elif isinstance(segment_index, int): - # Single segment specified - if segment_index not in available_segments: - raise ValueError(f"segment_index {segment_index} not found in data") - segments_to_use = [segment_index] - elif isinstance(segment_index, list): + elif isinstance(segment_indices, list): # Multiple segments specified - for idx in segment_index: + for idx in segment_indices: if idx not in available_segments: - raise ValueError(f"segment_index {idx} not found in data") - segments_to_use = segment_index + raise ValueError(f"segment_index {idx} not found in avialable segments {available_segments}") + segments_to_use = segment_indices else: - raise ValueError("segment_index must be int, list, or None") + raise ValueError("segment_index must be `list` or `None`") # Get all unit IDs present in any segment if not specified if unit_ids is None: @@ -391,7 +386,7 @@ class RasterWidget(BaseRasterWidget): A sorting object sorting_analyzer : SortingAnalyzer | None, default: None A sorting analyzer object - segment_index : int or list of int or None, default: None + segment_indices : list of int or None, default: None The segment index or indices to use. If None and there are multiple segments, defaults to 0. If a list of indices is provided, spike trains are concatenated across the specified segments. unit_ids : list @@ -406,7 +401,7 @@ def __init__( self, sorting=None, sorting_analyzer=None, - segment_index=None, + segment_indices=None, unit_ids=None, time_range=None, color="k", @@ -424,30 +419,7 @@ def __init__( sorting = self.ensure_sorting(sorting) - num_segments = sorting.get_num_segments() - - # Handle segment_index input - if num_segments > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 - - # Convert segment_index to list for consistent processing - if isinstance(segment_index, int): - segment_indices = [segment_index] - elif isinstance(segment_index, list): - segment_indices = segment_index - else: - raise ValueError("segment_index must be an int or a list of ints") - - # Validate segment indices - for idx in segment_indices: - if not isinstance(idx, int): - raise ValueError(f"Each segment index must be an integer, got {type(idx)}") - if idx < 0 or idx >= num_segments: - raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") + segment_indices = validate_segment_indices(sorting, segment_indices) if unit_ids is None: unit_ids = sorting.unit_ids diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index a1ac9d4af9..dd9bb20065 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -3,6 +3,7 @@ from warnings import warn import numpy as np +from spikeinterface.core import BaseSorting def get_some_colors( keys, @@ -349,3 +350,46 @@ def make_units_table_from_analyzer( ) return units_table + +def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting): + """ + Validate a list of segment indices for a sorting object. + + Parameters + ---------- + segment_indices : list of int + The segment index or indices to validate. + sorting : BaseSorting + The sorting object to validate against. + + Returns + ------- + list of int + A list of valid segment indices. + + Raises + ------ + ValueError + If the segment indices are not valid. + """ + num_segments = sorting.get_num_segments() + + # Handle segment_indices input + if segment_indices is None: + if num_segments > 1: + warn("Segment indices not specified. Using first available segment only.") + return [0] + + # Convert segment_index to list for consistent processing + if not isinstance(segment_indices, list): + raise ValueError("segment_indices must be a list of ints - available segments are: " + list(range(num_segments))) + + # Validate segment indices + for idx in segment_indices: + if not isinstance(idx, int): + raise ValueError(f"Each segment index must be an integer, got {type(idx)}") + if idx < 0 or idx >= num_segments: + raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") + + + return segment_indices \ No newline at end of file From 540db00a14517abbea6f59a383aa5f4f0a58e35c Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Wed, 30 Apr 2025 11:15:10 +0100 Subject: [PATCH 06/17] minor fixes --- src/spikeinterface/widgets/amplitudes.py | 6 ++++-- src/spikeinterface/widgets/motion.py | 8 ++++---- src/spikeinterface/widgets/rasters.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index a6fb9948ae..c09f3d82be 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -83,7 +83,7 @@ def __init__( # For SortingView, ensure we're only using a single segment if is_sortingview and len(segment_indices) > 1: warn("SortingView backend currently supports only single segment. Using first segment.") - segment_indices = segment_indices[0] + segment_indices = [segment_indices[0]] # Create multi-segment data structure (dict of dicts) spiketrains_by_segment = {} @@ -150,11 +150,13 @@ def __init__( first_segment = segment_indices[0] plot_data["spike_train_data"] = spiketrains_by_segment[first_segment] plot_data["y_axis_data"] = amplitudes_by_segment[first_segment] + print(plot_data["spike_train_data"]) + print(plot_data["y_axis_data"]) else: # Otherwise use the full dict of dicts structure with all segments plot_data["spike_train_data"] = spiketrains_by_segment plot_data["y_axis_data"] = amplitudes_by_segment - plot_data["segment_index"] = segment_indices + plot_data["segment_indices"] = segment_indices BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index f932ed44f6..9071698b2f 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -201,10 +201,10 @@ def __init__( if len(unique_segments) == 1: segment_indices = [int(unique_segments[0])] else: - raise ValueError("segment_index must be specified if there are multiple segments") + raise ValueError("segment_indices must be specified if there are multiple segments") if not isinstance(segment_indices, list): - raise ValueError("segment_index must be an int or a list of ints") + raise ValueError("segment_indices must be a list of ints") # Validate all segment indices exist in the data for idx in segment_indices: @@ -275,7 +275,7 @@ def __init__( plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, - segment_index=segment_indices, + segment_indices=segment_indices, y_lim=depth_lim, color_kwargs=color_kwargs, scatter_decimate=scatter_decimate, @@ -417,7 +417,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): commpon_drift_map_kwargs = dict( direction=dp.motion.direction, recording=dp.recording, - segment_index=dp.segment_index, + segment_indices=list(dp.segment_index), depth_lim=dp.depth_lim, scatter_decimate=dp.scatter_decimate, color_amplitude=dp.color_amplitude, diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index f03f980f24..6012cdaed2 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -419,7 +419,7 @@ def __init__( sorting = self.ensure_sorting(sorting) - segment_indices = validate_segment_indices(sorting, segment_indices) + segment_indices = validate_segment_indices(segment_indices, sorting) if unit_ids is None: unit_ids = sorting.unit_ids @@ -479,7 +479,7 @@ def __init__( plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, - segment_index=segment_indices, + segment_indices=segment_indices, x_lim=time_range, y_label="Unit id", unit_ids=unit_ids, From afdea786de73f7a068e85e27e6af8577737cd4b8 Mon Sep 17 00:00:00 2001 From: jakeswann1 Date: Wed, 30 Apr 2025 11:50:10 +0100 Subject: [PATCH 07/17] Update durations to use a list --- src/spikeinterface/widgets/amplitudes.py | 12 ++- src/spikeinterface/widgets/motion.py | 16 ++-- src/spikeinterface/widgets/rasters.py | 102 +++++++++-------------- 3 files changed, 51 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index c09f3d82be..5c0f304bba 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -125,18 +125,18 @@ def __init__( if plot_histograms and bins is None: bins = 100 - # Calculate total duration across all segments for x-axis limits - total_duration = 0 + # Calculate durations for all segments for x-axis limits + durations = [] for idx in segment_indices: duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency - total_duration += duration + durations.append(duration) # Build the plot data with the full dict of dicts structure plot_data = dict( unit_colors=unit_colors, plot_histograms=plot_histograms, bins=bins, - total_duration=total_duration, + durations=durations, unit_ids=unit_ids, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, @@ -150,8 +150,6 @@ def __init__( first_segment = segment_indices[0] plot_data["spike_train_data"] = spiketrains_by_segment[first_segment] plot_data["y_axis_data"] = amplitudes_by_segment[first_segment] - print(plot_data["spike_train_data"]) - print(plot_data["y_axis_data"]) else: # Otherwise use the full dict of dicts structure with all segments plot_data["spike_train_data"] = spiketrains_by_segment @@ -178,7 +176,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ] self.view = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + start_time_sec=0, end_time_sec=np.sum(dp.durations), plots=sa_items, hide_unit_selector=dp.hide_unit_selector ) self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 9071698b2f..afdb4a6963 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -256,8 +256,8 @@ def __init__( else: color_kwargs = dict(color=color, c=None, alpha=alpha) - # Calculate total duration for x-axis limits - total_duration = 0 + # Calculate segment durations for x-axis limits + durations = [] for seg_idx in segment_indices: if recording is not None and hasattr(recording, "get_duration"): duration = recording.get_duration(seg_idx) @@ -270,7 +270,7 @@ def __init__( duration = (max_sample + 1) / sampling_frequency else: duration = 0 - total_duration += duration + durations.append(duration) plot_data = dict( spike_train_data=spike_train_data, @@ -281,7 +281,7 @@ def __init__( scatter_decimate=scatter_decimate, title="Peak depth", y_label="Depth [um]", - total_duration=total_duration, + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) @@ -414,10 +414,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.recording, ) - commpon_drift_map_kwargs = dict( + common_drift_map_kwargs = dict( direction=dp.motion.direction, recording=dp.recording, - segment_indices=list(dp.segment_index), + segment_indices=[dp.segment_index], depth_lim=dp.depth_lim, scatter_decimate=dp.scatter_decimate, color_amplitude=dp.color_amplitude, @@ -434,7 +434,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.peak_locations, ax=ax0, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) _ = DriftRasterMapWidget( @@ -442,7 +442,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): corrected_location, ax=ax1, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 6012cdaed2..534a30d7d7 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -26,8 +26,8 @@ class BaseRasterWidget(BaseWidget): segment_indices : list | None, default: None For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. For single-segment data, this parameter is ignored. - total_duration : int | None, default: None - Duration of spike_train_data in seconds. + durations : list | None, default: None + List of durations per segment of spike_train_data in seconds. plot_histograms : bool, default: False Plot histogram of y-axis data in another subplot bins : int | None, default: None @@ -66,7 +66,7 @@ def __init__( y_axis_data: dict, unit_ids: list | None = None, segment_indices: list | None = None, - total_duration: int | None = None, + durations: list | None = None, plot_histograms: bool = False, bins: int | None = None, scatter_decimate: int = 1, @@ -112,67 +112,41 @@ def __init__( all_units.update(spike_train_data[seg_idx].keys()) unit_ids = list(all_units) - # Calculate segment durations and boundaries - segment_durations = [] - for seg_idx in segments_to_use: - max_time = 0 - for unit_id in unit_ids: - if unit_id in spike_train_data[seg_idx]: - unit_times = spike_train_data[seg_idx][unit_id] - if len(unit_times) > 0: - max_time = max(max_time, np.max(unit_times)) - segment_durations.append(max_time) - # Calculate cumulative durations for segment boundaries - cumulative_durations = [0] - for duration in segment_durations[:-1]: - cumulative_durations.append(cumulative_durations[-1] + duration) - - # Segment boundaries for visualization (only internal boundaries) - segment_boundaries = cumulative_durations[1:] if len(segments_to_use) > 1 else None + segment_boundaries = np.cumsum(durations) + cumulative_durations = np.concatenate([[0], segment_boundaries]) # Concatenate data across segments with proper time offsets - concatenated_spike_trains = {unit_id: [] for unit_id in unit_ids} - concatenated_y_axis = {unit_id: [] for unit_id in unit_ids} - - for i, seg_idx in enumerate(segments_to_use): - offset = cumulative_durations[i] - - for unit_id in unit_ids: - if unit_id in spike_train_data[seg_idx]: - # Get spike times for this unit in this segment - spike_times = spike_train_data[seg_idx][unit_id] - - # Adjust spike times by adding cumulative duration of previous segments - if offset > 0: - adjusted_times = spike_times + offset - else: - adjusted_times = spike_times - - # Get y-axis data for this unit in this segment - y_values = y_axis_data[seg_idx][unit_id] - - # Concatenate with any existing data - if len(concatenated_spike_trains[unit_id]) > 0: - concatenated_spike_trains[unit_id] = np.concatenate( - [concatenated_spike_trains[unit_id], adjusted_times] - ) - concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values]) - else: - concatenated_spike_trains[unit_id] = adjusted_times - concatenated_y_axis[unit_id] = y_values - - # Update spike train and y-axis data with concatenated values - processed_spike_train_data = concatenated_spike_trains - processed_y_axis_data = concatenated_y_axis - - # Calculate total duration from the data if not provided - if total_duration is None: - total_duration = cumulative_durations[-1] + segment_durations[-1] + concatenated_spike_trains = {unit_id: np.array([]) for unit_id in unit_ids} + concatenated_y_axis = {unit_id: np.array([]) for unit_id in unit_ids} + + for offset, spike_train_segment, y_axis_segment in zip( + cumulative_durations, + [spike_train_data[idx] for idx in segments_to_use], + [y_axis_data[idx] for idx in segments_to_use] + ): + # Process each unit in the current segment + for unit_id, spike_times in spike_train_segment.items(): + if unit_id not in unit_ids: + continue + + # Get y-axis values for this unit + y_values = y_axis_segment[unit_id] + + # Apply offset to spike times + adjusted_times = spike_times + offset + + # Add to concatenated data + concatenated_spike_trains[unit_id] = np.concatenate( + [concatenated_spike_trains[unit_id], adjusted_times] + ) + concatenated_y_axis[unit_id] = np.concatenate( + [concatenated_y_axis[unit_id], y_values] + ) plot_data = dict( - spike_train_data=processed_spike_train_data, - y_axis_data=processed_y_axis_data, + spike_train_data=concatenated_spike_trains, + y_axis_data=concatenated_y_axis, unit_ids=unit_ids, plot_histograms=plot_histograms, y_lim=y_lim, @@ -182,7 +156,7 @@ def __init__( unit_colors=unit_colors, y_label=y_label, title=title, - total_duration=total_duration, + durations=durations, plot_legend=plot_legend, bins=bins, y_ticks=y_ticks, @@ -275,7 +249,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): scatter_ax.set_ylim(*dp.y_lim) x_lim = dp.x_lim if x_lim is None: - x_lim = [0, dp.total_duration] + x_lim = [0, np.sum(dp.durations)] scatter_ax.set_xlim(x_lim) if dp.y_ticks: @@ -432,7 +406,7 @@ def __init__( unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} # Calculate total duration across all segments - total_duration = 0 + durations = [] for seg_idx in segment_indices: # Try to get duration from recording if available if recording is not None: @@ -446,7 +420,7 @@ def __init__( max_time = max(max_time, np.max(st)) duration = max_time - total_duration += duration + durations.append(duration) # Initialize dicts for this segment spike_train_data[seg_idx] = {} @@ -486,7 +460,7 @@ def __init__( unit_colors=unit_colors, plot_histograms=None, y_ticks=y_ticks, - total_duration=total_duration, + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) From 65a52803365359ca8afafb570d65027cedb5fd2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 10:53:05 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 5 ++++- src/spikeinterface/widgets/motion.py | 2 +- src/spikeinterface/widgets/rasters.py | 18 ++++++++---------- src/spikeinterface/widgets/utils.py | 11 +++++++---- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 5c0f304bba..3d4d1d41fd 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -176,7 +176,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ] self.view = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=np.sum(dp.durations), plots=sa_items, hide_unit_selector=dp.hide_unit_selector + start_time_sec=0, + end_time_sec=np.sum(dp.durations), + plots=sa_items, + hide_unit_selector=dp.hide_unit_selector, ) self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index c93a0d2eeb..024baff29a 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -202,7 +202,7 @@ def __init__( segment_indices = [int(unique_segments[0])] else: raise ValueError("segment_indices must be specified if there are multiple segments") - + if not isinstance(segment_indices, list): raise ValueError("segment_indices must be a list of ints") diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 534a30d7d7..55d12e6102 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -113,36 +113,34 @@ def __init__( unit_ids = list(all_units) # Calculate cumulative durations for segment boundaries - segment_boundaries = np.cumsum(durations) - cumulative_durations = np.concatenate([[0], segment_boundaries]) + segment_boundaries = np.cumsum(durations) + cumulative_durations = np.concatenate([[0], segment_boundaries]) # Concatenate data across segments with proper time offsets concatenated_spike_trains = {unit_id: np.array([]) for unit_id in unit_ids} concatenated_y_axis = {unit_id: np.array([]) for unit_id in unit_ids} for offset, spike_train_segment, y_axis_segment in zip( - cumulative_durations, + cumulative_durations, [spike_train_data[idx] for idx in segments_to_use], - [y_axis_data[idx] for idx in segments_to_use] + [y_axis_data[idx] for idx in segments_to_use], ): # Process each unit in the current segment for unit_id, spike_times in spike_train_segment.items(): if unit_id not in unit_ids: continue - + # Get y-axis values for this unit y_values = y_axis_segment[unit_id] - + # Apply offset to spike times adjusted_times = spike_times + offset - + # Add to concatenated data concatenated_spike_trains[unit_id] = np.concatenate( [concatenated_spike_trains[unit_id], adjusted_times] ) - concatenated_y_axis[unit_id] = np.concatenate( - [concatenated_y_axis[unit_id], y_values] - ) + concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values]) plot_data = dict( spike_train_data=concatenated_spike_trains, diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index dd9bb20065..898142f515 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -5,6 +5,7 @@ from spikeinterface.core import BaseSorting + def get_some_colors( keys, color_engine="auto", @@ -351,7 +352,8 @@ def make_units_table_from_analyzer( return units_table -def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting): + +def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting): """ Validate a list of segment indices for a sorting object. @@ -382,7 +384,9 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSo # Convert segment_index to list for consistent processing if not isinstance(segment_indices, list): - raise ValueError("segment_indices must be a list of ints - available segments are: " + list(range(num_segments))) + raise ValueError( + "segment_indices must be a list of ints - available segments are: " + list(range(num_segments)) + ) # Validate segment indices for idx in segment_indices: @@ -391,5 +395,4 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSo if idx < 0 or idx >= num_segments: raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") - - return segment_indices \ No newline at end of file + return segment_indices From c43fa7c13943236150bfd761d4ba2b023e172f51 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Mon, 12 May 2025 11:56:31 -0400 Subject: [PATCH 09/17] add test for validate_segment_indices --- .../widgets/tests/test_widgets_utils.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index 2131969c2c..de096d0197 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -1,4 +1,7 @@ -from spikeinterface.widgets.utils import get_some_colors +import pytest + +from spikeinterface import generate_sorting +from spikeinterface.widgets.utils import get_some_colors, validate_segment_indices def test_get_some_colors(): @@ -19,5 +22,34 @@ def test_get_some_colors(): # print(colors) +def test_validate_segment_indices(): + # Setup + sorting_single = generate_sorting(durations=[5]) # 1 segment + sorting_multiple = generate_sorting(durations=[5, 10, 15, 20, 25]) # 5 segments + + # Test None with single segment + assert validate_segment_indices(None, sorting_single) == [0] + + # Test None with multiple segments + with pytest.warns(UserWarning): + assert validate_segment_indices(None, sorting_multiple) == [0] + + # Test valid indices + assert validate_segment_indices([0], sorting_single) == [0] + assert validate_segment_indices([0, 1, 4], sorting_multiple) == [0, 1, 4] + + # Test invalid type + with pytest.raises(TypeError): + validate_segment_indices(0, sorting_multiple) + + # Test invalid index type + with pytest.raises(ValueError): + validate_segment_indices([0, "1"], sorting_multiple) + + # Test out of range + with pytest.raises(ValueError): + validate_segment_indices([5], sorting_multiple) + if __name__ == "__main__": test_get_some_colors() + test_validate_segment_indices() From 55e773e81df3879abb1405a9adeb0b73687b2038 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 May 2025 16:24:21 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../widgets/tests/test_widgets_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index de096d0197..db6f1bf537 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -26,30 +26,31 @@ def test_validate_segment_indices(): # Setup sorting_single = generate_sorting(durations=[5]) # 1 segment sorting_multiple = generate_sorting(durations=[5, 10, 15, 20, 25]) # 5 segments - + # Test None with single segment assert validate_segment_indices(None, sorting_single) == [0] - + # Test None with multiple segments with pytest.warns(UserWarning): assert validate_segment_indices(None, sorting_multiple) == [0] - + # Test valid indices assert validate_segment_indices([0], sorting_single) == [0] assert validate_segment_indices([0, 1, 4], sorting_multiple) == [0, 1, 4] - + # Test invalid type with pytest.raises(TypeError): validate_segment_indices(0, sorting_multiple) - + # Test invalid index type with pytest.raises(ValueError): validate_segment_indices([0, "1"], sorting_multiple) - + # Test out of range with pytest.raises(ValueError): validate_segment_indices([5], sorting_multiple) + if __name__ == "__main__": test_get_some_colors() test_validate_segment_indices() From b0985576695aab738f1a9498225f828ae6f318b7 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Mon, 12 May 2025 12:56:13 -0400 Subject: [PATCH 11/17] simplify segment duration computation --- src/spikeinterface/widgets/amplitudes.py | 7 +-- src/spikeinterface/widgets/motion.py | 27 ++++++------ src/spikeinterface/widgets/rasters.py | 44 +++++++------------ .../widgets/tests/test_widgets_utils.py | 39 +++++++++++++++- src/spikeinterface/widgets/utils.py | 29 ++++++++++++ 5 files changed, 97 insertions(+), 49 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 3d4d1d41fd..6f36baf521 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -5,7 +5,7 @@ from .rasters import BaseRasterWidget from .base import BaseWidget, to_attr -from .utils import get_some_colors, validate_segment_indices +from .utils import get_some_colors, validate_segment_indices, get_segment_durations from spikeinterface.core.sortinganalyzer import SortingAnalyzer @@ -126,10 +126,7 @@ def __init__( bins = 100 # Calculate durations for all segments for x-axis limits - durations = [] - for idx in segment_indices: - duration = sorting_analyzer.get_num_samples(idx) / sorting_analyzer.sampling_frequency - durations.append(duration) + durations = get_segment_durations(sorting) # Build the plot data with the full dict of dicts structure plot_data = dict( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 024baff29a..6b2936afb7 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -6,6 +6,7 @@ from spikeinterface.core import BaseRecording, SortingAnalyzer from .rasters import BaseRasterWidget +from .utils import get_segment_durations from spikeinterface.core.motion import Motion @@ -259,20 +260,18 @@ def __init__( color_kwargs = dict(color=color, c=None, alpha=alpha) # Calculate segment durations for x-axis limits - durations = [] - for seg_idx in segment_indices: - if recording is not None and hasattr(recording, "get_duration"): - duration = recording.get_duration(seg_idx) - else: - # Estimate from spike times - segment_mask = filtered_peaks["segment_index"] == seg_idx - segment_peaks = filtered_peaks[segment_mask] - if len(segment_peaks) > 0: - max_sample = np.max(segment_peaks["sample_index"]) - duration = (max_sample + 1) / sampling_frequency - else: - duration = 0 - durations.append(duration) + if recording is not None: + durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices] + else: + # Find boundaries between segments using searchsorted + segment_boundaries = [np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices] + + # Calculate durations from max sample in each segment + durations = [ + (np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency + if start < end else 0 + for (start, end) in segment_boundaries + ] plot_data = dict( spike_train_data=spike_train_data, diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 55d12e6102..d1452b15cb 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,7 +4,7 @@ from warnings import warn from .base import BaseWidget, to_attr, default_backend_kwargs -from .utils import get_some_colors, validate_segment_indices +from .utils import get_some_colors, validate_segment_indices, get_segment_durations class BaseRasterWidget(BaseWidget): @@ -380,14 +380,12 @@ def __init__( backend=None, **backend_kwargs, ): - recording = None if sorting is None and sorting_analyzer is None: raise Exception("Must supply either a sorting or a sorting_analyzer") elif sorting is not None and sorting_analyzer is not None: raise Exception("Should supply either a sorting or a sorting_analyzer, not both") elif sorting_analyzer is not None: sorting = sorting_analyzer.sorting - recording = sorting_analyzer.recording sorting = self.ensure_sorting(sorting) @@ -403,37 +401,25 @@ def __init__( # Create a lookup dictionary for unit indices unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} - # Calculate total duration across all segments - durations = [] - for seg_idx in segment_indices: - # Try to get duration from recording if available - if recording is not None: - duration = recording.get_duration(seg_idx) - else: - # Fallback: estimate from max spike time - max_time = 0 - for unit_id in unit_ids: - st = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) - if len(st) > 0: - max_time = max(max_time, np.max(st)) - duration = max_time + # Get all spikes at once + spikes = sorting.to_spike_vector() - durations.append(duration) + # Estimate segment duration from max spike time in each segment + durations = get_segment_durations(sorting) - # Initialize dicts for this segment - spike_train_data[seg_idx] = {} - y_axis_data[seg_idx] = {} + # Extract spike data for all segments and units at once + spike_train_data = {seg_idx: {} for seg_idx in segment_indices} + y_axis_data = {seg_idx: {} for seg_idx in segment_indices} - # Get spike trains for each unit in this segment + for seg_idx in segment_indices: for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id, segment_index=seg_idx, return_times=True) - - # Store spike trains + # Get spikes for this segment and unit + mask = (spikes['segment_index'] == seg_idx) & (spikes['unit_index'] == unit_id) + spike_times = spikes['sample_index'][mask] / sorting.sampling_frequency + + # Store data spike_train_data[seg_idx][unit_id] = spike_times - - # Create raster locations (y-values for plotting) - unit_index = unit_indices_map[unit_id] - y_axis_data[seg_idx][unit_id] = unit_index * np.ones(len(spike_times)) + y_axis_data[seg_idx][unit_id] = unit_indices_map[unit_id] * np.ones(len(spike_times)) # Apply time range filtering if specified if time_range is not None: diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index db6f1bf537..e29309bea0 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -1,7 +1,7 @@ import pytest from spikeinterface import generate_sorting -from spikeinterface.widgets.utils import get_some_colors, validate_segment_indices +from spikeinterface.widgets.utils import get_some_colors, validate_segment_indices, get_segment_durations def test_get_some_colors(): @@ -50,7 +50,44 @@ def test_validate_segment_indices(): with pytest.raises(ValueError): validate_segment_indices([5], sorting_multiple) +def test_get_segment_durations(): + from spikeinterface import generate_sorting + + # Test with a normal multi-segment sorting + durations = [5.0, 10.0, 15.0] + + # Create sorting with high fr to ensure spikes near the end segments + sorting = generate_sorting( + durations=durations, + firing_rates=15.0, + ) + + # Calculate durations + calculated_durations = get_segment_durations(sorting) + + # Check results + assert len(calculated_durations) == len(durations) + # Durations should be approximately correct + for calculated_duration, expected_duration in zip(calculated_durations, durations): + # Duration should be <= expected (spikes can't be after the end) + assert calculated_duration <= expected_duration + # And reasonably close + tolerance = max(0.1 * expected_duration, 0.1) + assert expected_duration - calculated_duration < tolerance + + # Test with single-segment sorting + sorting_single = generate_sorting( + durations=[7.0], + firing_rates=15.0, + ) + + single_duration = get_segment_durations(sorting_single)[0] + + # Test that the calculated duration is reasonable + assert single_duration <= 7.0 + assert 7.0 - single_duration < 0.7 # Within 10% if __name__ == "__main__": test_get_some_colors() test_validate_segment_indices() + test_get_segment_durations() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 898142f515..3e55b71bd5 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -396,3 +396,32 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})") return segment_indices + +def get_segment_durations(sorting: BaseSorting) -> list[float]: + """ + Calculate the duration of each segment in a sorting object. + + Parameters + ---------- + sorting : BaseSorting + The sorting object containing spike data + + Returns + ------- + list[float] + List of segment durations in seconds + """ + spikes = sorting.to_spike_vector() + segment_indices = np.unique(spikes['segment_index']) + + durations = [] + for seg_idx in segment_indices: + segment_mask = spikes['segment_index'] == seg_idx + if np.any(segment_mask): + max_sample = np.max(spikes['sample_index'][segment_mask]) + duration = max_sample / sorting.sampling_frequency + else: + duration = 0 + durations.append(duration) + + return durations \ No newline at end of file From 9dfd1a4150dd6102347f82bcf44b5f25e959787b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 May 2025 16:56:41 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/motion.py | 9 +++++---- src/spikeinterface/widgets/rasters.py | 6 +++--- .../widgets/tests/test_widgets_utils.py | 16 +++++++++------- src/spikeinterface/widgets/utils.py | 17 +++++++++-------- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 6b2936afb7..dbc271f305 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -264,12 +264,13 @@ def __init__( durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices] else: # Find boundaries between segments using searchsorted - segment_boundaries = [np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices] - + segment_boundaries = [ + np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices + ] + # Calculate durations from max sample in each segment durations = [ - (np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency - if start < end else 0 + (np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency if start < end else 0 for (start, end) in segment_boundaries ] diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index d1452b15cb..4219b34c3d 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -414,9 +414,9 @@ def __init__( for seg_idx in segment_indices: for unit_id in unit_ids: # Get spikes for this segment and unit - mask = (spikes['segment_index'] == seg_idx) & (spikes['unit_index'] == unit_id) - spike_times = spikes['sample_index'][mask] / sorting.sampling_frequency - + mask = (spikes["segment_index"] == seg_idx) & (spikes["unit_index"] == unit_id) + spike_times = spikes["sample_index"][mask] / sorting.sampling_frequency + # Store data spike_train_data[seg_idx][unit_id] = spike_times y_axis_data[seg_idx][unit_id] = unit_indices_map[unit_id] * np.ones(len(spike_times)) diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index e29309bea0..ff4bfd957c 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -50,21 +50,22 @@ def test_validate_segment_indices(): with pytest.raises(ValueError): validate_segment_indices([5], sorting_multiple) + def test_get_segment_durations(): from spikeinterface import generate_sorting - + # Test with a normal multi-segment sorting durations = [5.0, 10.0, 15.0] - + # Create sorting with high fr to ensure spikes near the end segments sorting = generate_sorting( durations=durations, firing_rates=15.0, ) - + # Calculate durations calculated_durations = get_segment_durations(sorting) - + # Check results assert len(calculated_durations) == len(durations) # Durations should be approximately correct @@ -74,19 +75,20 @@ def test_get_segment_durations(): # And reasonably close tolerance = max(0.1 * expected_duration, 0.1) assert expected_duration - calculated_duration < tolerance - + # Test with single-segment sorting sorting_single = generate_sorting( durations=[7.0], firing_rates=15.0, ) - + single_duration = get_segment_durations(sorting_single)[0] - + # Test that the calculated duration is reasonable assert single_duration <= 7.0 assert 7.0 - single_duration < 0.7 # Within 10% + if __name__ == "__main__": test_get_some_colors() test_validate_segment_indices() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 3e55b71bd5..9c5892a937 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -397,31 +397,32 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices + def get_segment_durations(sorting: BaseSorting) -> list[float]: """ Calculate the duration of each segment in a sorting object. - + Parameters ---------- sorting : BaseSorting The sorting object containing spike data - + Returns ------- list[float] List of segment durations in seconds """ spikes = sorting.to_spike_vector() - segment_indices = np.unique(spikes['segment_index']) - + segment_indices = np.unique(spikes["segment_index"]) + durations = [] for seg_idx in segment_indices: - segment_mask = spikes['segment_index'] == seg_idx + segment_mask = spikes["segment_index"] == seg_idx if np.any(segment_mask): - max_sample = np.max(spikes['sample_index'][segment_mask]) + max_sample = np.max(spikes["sample_index"][segment_mask]) duration = max_sample / sorting.sampling_frequency else: duration = 0 durations.append(duration) - - return durations \ No newline at end of file + + return durations From 9684e0aaadfd35b7f0a2076484d3c484a9c9b8bb Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Thu, 3 Jul 2025 22:22:41 +0100 Subject: [PATCH 13/17] Address Chris' comments --- src/spikeinterface/widgets/amplitudes.py | 15 +++++++++++++++ src/spikeinterface/widgets/motion.py | 22 ++++++++++++++++++---- src/spikeinterface/widgets/rasters.py | 21 ++++++++++++++++++--- src/spikeinterface/widgets/utils.py | 14 +++++--------- 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 6f36baf521..57fd846ff4 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -60,9 +60,24 @@ def __init__( plot_histograms=False, bins=None, plot_legend=True, + segment_index=None, backend=None, **backend_kwargs, ): + import warnings + # Handle deprecation of segment_index parameter + if segment_index is not None: + warnings.warn( + "The 'segment_index' parameter is deprecated and will be removed in a future version. " + "Use 'segment_indices' instead.", + DeprecationWarning, + stacklevel=2 + ) + if segment_indices is None: + if isinstance(segment_index, int): + segment_indices = [segment_index] + else: + segment_indices = segment_index sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) sorting = sorting_analyzer.sorting diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index dbc271f305..aebaf2340a 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -158,12 +158,28 @@ def __init__( color: str = "Gray", clim: tuple[float, float] | None = None, alpha: float = 1, + segment_index: int | list[int] | None = None, backend: str | None = None, **backend_kwargs, ): + import warnings from matplotlib.pyplot import colormaps from matplotlib.colors import Normalize + # Handle deprecation of segment_index parameter + if segment_index is not None: + warnings.warn( + "The 'segment_index' parameter is deprecated and will be removed in a future version. " + "Use 'segment_indices' instead.", + DeprecationWarning, + stacklevel=2 + ) + if segment_indices is None: + if isinstance(segment_index, int): + segment_indices = [segment_index] + else: + segment_indices = segment_index + assert peaks is not None or sorting_analyzer is not None if peaks is not None: @@ -269,10 +285,8 @@ def __init__( ] # Calculate durations from max sample in each segment - durations = [ - (np.max(filtered_peaks["sample_index"][start:end]) + 1) / sampling_frequency if start < end else 0 - for (start, end) in segment_boundaries - ] + durations = [(filtered_peaks["sample_index"][end-1]+1) / sampling_frequency for (_, end) in segment_boundaries ] + plot_data = dict( spike_train_data=spike_train_data, diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 9926948215..a517eafad7 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -374,16 +374,32 @@ class RasterWidget(BaseRasterWidget): def __init__( self, sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None, - segment_index: int | None = None, + segment_indices: int | None = None, unit_ids: list | None = None, time_range: list | None = None, color="k", backend: str | None = None, sorting: BaseSorting | None = None, sorting_analyzer: SortingAnalyzer | None = None, + segment_index: int | None = None, **backend_kwargs, ): + import warnings + # Handle deprecation of segment_index parameter + if segment_index is not None: + warnings.warn( + "The 'segment_index' parameter is deprecated and will be removed in a future version. " + "Use 'segment_indices' instead.", + DeprecationWarning, + stacklevel=2 + ) + if segment_indices is None: + if isinstance(segment_index, int): + segment_indices = [segment_index] + else: + segment_indices = segment_index + if sorting is not None: # When removed, make `sorting_analyzer_or_sorting` a required argument rather than None. deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead" @@ -421,8 +437,7 @@ def __init__( for seg_idx in segment_indices: for unit_id in unit_ids: # Get spikes for this segment and unit - mask = (spikes["segment_index"] == seg_idx) & (spikes["unit_index"] == unit_id) - spike_times = spikes["sample_index"][mask] / sorting.sampling_frequency + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency # Store data spike_train_data[seg_idx][unit_id] = spike_times diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 9c5892a937..53fa527276 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -415,14 +415,10 @@ def get_segment_durations(sorting: BaseSorting) -> list[float]: spikes = sorting.to_spike_vector() segment_indices = np.unique(spikes["segment_index"]) - durations = [] - for seg_idx in segment_indices: - segment_mask = spikes["segment_index"] == seg_idx - if np.any(segment_mask): - max_sample = np.max(spikes["sample_index"][segment_mask]) - duration = max_sample / sorting.sampling_frequency - else: - duration = 0 - durations.append(duration) + segment_boundaries = [ + np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices + ] + + durations = [(spikes["sample_index"][end-1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries] return durations From 6c20215d322ab72e31e0d1693a72e14cd6dbe01c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 21:24:30 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/amplitudes.py | 3 ++- src/spikeinterface/widgets/motion.py | 7 ++++--- src/spikeinterface/widgets/rasters.py | 7 +++++-- src/spikeinterface/widgets/utils.py | 8 ++++---- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 57fd846ff4..1f4c1eab47 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -65,13 +65,14 @@ def __init__( **backend_kwargs, ): import warnings + # Handle deprecation of segment_index parameter if segment_index is not None: warnings.warn( "The 'segment_index' parameter is deprecated and will be removed in a future version. " "Use 'segment_indices' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if segment_indices is None: if isinstance(segment_index, int): diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index aebaf2340a..eac1b2713c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -172,7 +172,7 @@ def __init__( "The 'segment_index' parameter is deprecated and will be removed in a future version. " "Use 'segment_indices' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if segment_indices is None: if isinstance(segment_index, int): @@ -285,8 +285,9 @@ def __init__( ] # Calculate durations from max sample in each segment - durations = [(filtered_peaks["sample_index"][end-1]+1) / sampling_frequency for (_, end) in segment_boundaries ] - + durations = [ + (filtered_peaks["sample_index"][end - 1] + 1) / sampling_frequency for (_, end) in segment_boundaries + ] plot_data = dict( spike_train_data=spike_train_data, diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index a517eafad7..55032b38eb 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -386,13 +386,14 @@ def __init__( ): import warnings + # Handle deprecation of segment_index parameter if segment_index is not None: warnings.warn( "The 'segment_index' parameter is deprecated and will be removed in a future version. " "Use 'segment_indices' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if segment_indices is None: if isinstance(segment_index, int): @@ -437,7 +438,9 @@ def __init__( for seg_idx in segment_indices: for unit_id in unit_ids: # Get spikes for this segment and unit - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency + spike_times = ( + sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency + ) # Store data spike_train_data[seg_idx][unit_id] = spike_times diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 53fa527276..a003742939 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -415,10 +415,10 @@ def get_segment_durations(sorting: BaseSorting) -> list[float]: spikes = sorting.to_spike_vector() segment_indices = np.unique(spikes["segment_index"]) - segment_boundaries = [ - np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices - ] + segment_boundaries = [ + np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices + ] - durations = [(spikes["sample_index"][end-1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries] + durations = [(spikes["sample_index"][end - 1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries] return durations From 032bdb558c9c0f98d8c804b48e801f22f93ee1e1 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Sun, 6 Jul 2025 22:50:24 +0100 Subject: [PATCH 15/17] address Chris' comments --- src/spikeinterface/widgets/amplitudes.py | 2 +- src/spikeinterface/widgets/motion.py | 1 - src/spikeinterface/widgets/rasters.py | 5 +---- src/spikeinterface/widgets/tests/test_widgets_utils.py | 7 +++++-- src/spikeinterface/widgets/utils.py | 4 ++-- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 1f4c1eab47..c3aeb221ab 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -142,7 +142,7 @@ def __init__( bins = 100 # Calculate durations for all segments for x-axis limits - durations = get_segment_durations(sorting) + durations = get_segment_durations(sorting, segment_indices) # Build the plot data with the full dict of dicts structure plot_data = dict( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index eac1b2713c..3c08217300 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -6,7 +6,6 @@ from spikeinterface.core import BaseRecording, SortingAnalyzer from .rasters import BaseRasterWidget -from .utils import get_segment_durations from spikeinterface.core.motion import Motion diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 55032b38eb..757401d77c 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -425,11 +425,8 @@ def __init__( # Create a lookup dictionary for unit indices unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} - # Get all spikes at once - spikes = sorting.to_spike_vector() - # Estimate segment duration from max spike time in each segment - durations = get_segment_durations(sorting) + durations = get_segment_durations(sorting, segment_indices) # Extract spike data for all segments and units at once spike_train_data = {seg_idx: {} for seg_idx in segment_indices} diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index ff4bfd957c..dfb611119d 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -1,3 +1,4 @@ +from matplotlib.lines import segment_hits import pytest from spikeinterface import generate_sorting @@ -63,8 +64,10 @@ def test_get_segment_durations(): firing_rates=15.0, ) + segment_indices = list(range(sorting.get_num_segments())) + # Calculate durations - calculated_durations = get_segment_durations(sorting) + calculated_durations = get_segment_durations(sorting, segment_indices) # Check results assert len(calculated_durations) == len(durations) @@ -82,7 +85,7 @@ def test_get_segment_durations(): firing_rates=15.0, ) - single_duration = get_segment_durations(sorting_single)[0] + single_duration = get_segment_durations(sorting_single, [0])[0] # Test that the calculated duration is reasonable assert single_duration <= 7.0 diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index a003742939..3adcc05636 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -4,6 +4,7 @@ import numpy as np from spikeinterface.core import BaseSorting +from traitlets import Int def get_some_colors( @@ -398,7 +399,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices -def get_segment_durations(sorting: BaseSorting) -> list[float]: +def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]: """ Calculate the duration of each segment in a sorting object. @@ -413,7 +414,6 @@ def get_segment_durations(sorting: BaseSorting) -> list[float]: List of segment durations in seconds """ spikes = sorting.to_spike_vector() - segment_indices = np.unique(spikes["segment_index"]) segment_boundaries = [ np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices From d889d41062cd6415335435e8d34e1621d2b3d276 Mon Sep 17 00:00:00 2001 From: Jake Swann Date: Sun, 6 Jul 2025 22:56:33 +0100 Subject: [PATCH 16/17] oops --- src/spikeinterface/widgets/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 3adcc05636..26b5ec2e59 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -4,8 +4,6 @@ import numpy as np from spikeinterface.core import BaseSorting -from traitlets import Int - def get_some_colors( keys, From 260a951fe8e144645cb66f4c5dab8eeae33e2323 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 6 Jul 2025 21:57:05 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 26b5ec2e59..75fb74cfae 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -5,6 +5,7 @@ from spikeinterface.core import BaseSorting + def get_some_colors( keys, color_engine="auto",