Skip to content

Commit 7ac2ae0

Browse files
authored
Add install check for components sorters (#4574)
1 parent 25c7dfc commit 7ac2ae0

4 files changed

Lines changed: 67 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,26 @@ widgets = [
168168
"distinctipy",
169169
]
170170

171+
lupin = [
172+
"scipy",
173+
"numba",
174+
"scikit-learn",
175+
"torch"
176+
]
177+
178+
spykingcircus2 = [
179+
"scipy",
180+
"hdbscan",
181+
"numba",
182+
]
183+
184+
tridesclous2= [
185+
"scipy",
186+
"numba",
187+
"scikit-learn",
188+
"torch",
189+
]
190+
171191
# `full` installs every module's optional feature deps. Defined as the union of
172192
# per-module extras so adding a dep to a module propagates here automatically.
173193
full = [
@@ -272,8 +292,11 @@ test-comparison = [
272292

273293
test-sorters-internal = [
274294
{include-group = "test-common"},
275-
"torch", # spyking_circus2 template matching
276-
"hdbscan>=0.8.33", # simplesorter / tridesclous2
295+
"scipy",
296+
"torch",
297+
"hdbscan>=0.8.33",
298+
"numba",
299+
"scikit-learn",
277300
]
278301
test-sorters = [{include-group = "test-sorters-internal"}]
279302
test-curation = [{include-group = "test-common"}]

src/spikeinterface/sorters/internal/lupin.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,22 @@ class LupinSorter(ComponentsBasedSorter):
102102
"debug": "Save debug files",
103103
}
104104

105+
installation_mesg = "\tpip install 'spikeinterface[lupin]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[lupin]'"
106+
105107
handle_multi_segment = True
106108

109+
@classmethod
110+
def is_installed(cls):
111+
import importlib.util
112+
113+
lupin_deps = ["scipy", "numba", "sklearn", "torch"]
114+
115+
for package_name in lupin_deps:
116+
if not importlib.util.find_spec(package_name):
117+
return False
118+
119+
return True
120+
107121
@classmethod
108122
def get_sorter_version(cls):
109123
return "2026.01"

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,20 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
8686
In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes
8787
being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface"""
8888

89+
installation_mesg = "\tpip install 'spikeinterface[spykingcircus2]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[spykingcircus2]'"
90+
91+
@classmethod
92+
def is_installed(cls):
93+
import importlib.util
94+
95+
spykingcircus2_deps = ["scipy", "numba", "hdbscan"]
96+
97+
for package_name in spykingcircus2_deps:
98+
if not importlib.util.find_spec(package_name):
99+
return False
100+
101+
return True
102+
89103
@classmethod
90104
def get_sorter_version(cls):
91105
return "2025.12"

src/spikeinterface/sorters/internal/tridesclous2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
9797

9898
handle_multi_segment = True
9999

100+
installation_mesg = "\tpip install 'spikeinterface[tridesclous2]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[tridesclous2]'"
101+
102+
@classmethod
103+
def is_installed(cls):
104+
import importlib.util
105+
106+
tridesclous2_deps = ["scipy", "numba", "hdbscan"]
107+
108+
for package_name in tridesclous2_deps:
109+
if not importlib.util.find_spec(package_name):
110+
return False
111+
112+
return True
113+
100114
@classmethod
101115
def get_sorter_version(cls):
102116
return "2026.01"

0 commit comments

Comments
 (0)