Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src/spikeinterface/sorters/external/dartsort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from pathlib import Path
from packaging.version import parse

from ..basesorter import BaseSorter
from ...core import NumpyFolderSorting, NumpySorting

class DartsortSorter(BaseSorter):
"""Dartsort wrapper"""

sorter_name = "dartsort"
requires_locations = False
compatible_with_parallel = {"loky": False, "multiprocessing": False, "threading": False}
sorter_description = "Dartsort is the Columbia university sorter made with love by Charlie Windolf and Liam Paninski's team."
installation_mesg = """\nTo use dartsort run:\n
>>> pip install dartsort

More information on mountainsort5 at:
* https://github.com/cwindolf/dartsort
"""

_default_params = {
}

_params_description = {
}

@classmethod
def _dynamic_params(cls):
from dartsort import DARTsortUserConfig
from pydantic import RootModel
# the trick is to transform the DARTsortUserConfig (a pydantic.dataclass) into a pydantic model
Model = RootModel[DARTsortUserConfig]
# so we can dump to dict
cfg = Model(DARTsortUserConfig())
default_params = cfg.model_dump(mode='python')
# and retrieve properties
schema = Model.model_json_schema()
default_params_descriptions = {}
for k, props in schema['$defs']['DARTsortUserConfig']['properties'].items():
default_params_descriptions[k] = props['title']

return default_params, default_params_descriptions

@classmethod
def is_installed(cls):
try:
import dartsort
HAVE_DARTSORT = True
except ImportError:
HAVE_DARTSORT = False

return HAVE_DARTSORT

@staticmethod
def get_sorter_version():
import dartsort
if hasattr(dartsort, "__version__"):
return dartsort.__version__
return "unknown"

@classmethod
def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
pass

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
from dartsort import dartsort as dartsort_main
from dartsort import DARTsortUserConfig

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

# dartsort config are set using dataclass we need to map this
cfg = DARTsortUserConfig(**params)

ret = dartsort_main(
recording,
sorter_output_folder,
cfg,
)
# the dartsort_sorting is not the spikeinterface sorting!!!
dartsort_sorting = ret['sorting']

times_samples = dartsort_sorting.times_samples
labels = dartsort_sorting.labels
mask = labels >= 0

sorting = NumpySorting.from_samples_and_labels(
[times_samples[mask]], [labels[mask]], dartsort_sorting.sampling_frequency
)

NumpyFolderSorting.write_sorting(sorting, sorter_output_folder / "final_darsort_sorting")

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
sorter_output_folder = Path(sorter_output_folder)
sorting = NumpyFolderSorting(sorter_output_folder / "final_darsort_sorting")
return sorting
18 changes: 18 additions & 0 deletions src/spikeinterface/sorters/external/tests/test_dartsort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
import pytest

from spikeinterface.sorters import DartsortSorter
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite


@pytest.mark.skipif(not DartsortSorter.is_installed(), reason="dartsort not installed")
class DartsortCommonTestSuite(SorterCommonTestSuite, unittest.TestCase):
SorterClass = DartsortSorter


if __name__ == "__main__":
from pathlib import Path
test = DartsortCommonTestSuite()
test.cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "sorters"
test.setUp()
test.test_with_run()
2 changes: 2 additions & 0 deletions src/spikeinterface/sorters/sorterlist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .external.combinato import CombinatoSorter
from .external.dartsort import DartsortSorter
from .external.hdsort import HDSortSorter
from .external.herdingspikes import HerdingspikesSorter
from .external.ironclust import IronClustSorter
Expand Down Expand Up @@ -27,6 +28,7 @@
sorter_full_list = [
# external
CombinatoSorter,
DartsortSorter,
HDSortSorter,
HerdingspikesSorter,
IronClustSorter,
Expand Down
Loading