Skip to content
30 changes: 25 additions & 5 deletions src/spikeinterface/widgets/isi_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import numpy as np
from warnings import warn

from spikeinterface.core import SortingAnalyzer, BaseSorting

from .base import BaseWidget, to_attr
from .utils import get_unit_colors


class ISIDistributionWidget(BaseWidget):
Expand All @@ -13,18 +14,37 @@ class ISIDistributionWidget(BaseWidget):

Parameters
----------
sorting : SortingExtractor
The sorting extractor object
sorting_analyzer_or_sorting : SortingAnalyzer | BaseSorting | None = None
The object containing the sorting information for the isi distribution plot
unit_ids : list
List of unit ids
bins_ms : int
Bin size in ms
window_ms : float
Window size in ms

sorting : BaseSorting | None = None
Deprecated argument.
"""

def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs):
def __init__(
self,
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
unit_ids: list | None = None,
window_ms: float = 100.0,
bin_ms: float = 1.0,
backend: str | None = None,
sorting: BaseSorting | None = None,
**backend_kwargs,
):

if sorting is not None:
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
sorting_analyzer_or_sorting = sorting

sorting = self.ensure_sorting(sorting_analyzer_or_sorting)

if unit_ids is None:
unit_ids = sorting.get_unit_ids()

Expand Down