Skip to content

Commit a321af2

Browse files
committed
fix: check if periods and valid periods are match
1 parent 1c29d0e commit a321af2

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

src/spikeinterface/metrics/quality/quality_metrics.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from spikeinterface.core.template_tools import get_template_extremum_channel
77
from spikeinterface.core.sortinganalyzer import register_result_extension
88
from spikeinterface.core.analyzer_extension_core import BaseMetricExtension
9+
from spikeinterface.core.sorting_tools import cast_periods_to_unit_period_dtype
910

1011
from .misc_metrics import misc_metrics_list
1112
from .pca_metrics import pca_metrics_list
@@ -115,7 +116,12 @@ def _set_params(
115116
metric_names = [m for m in metric_names if m not in pc_metric_names]
116117

117118
if use_valid_periods:
118-
periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy")
119+
valid_periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy")
120+
if periods is not None:
121+
provided_periods = cast_periods_to_unit_period_dtype(np.asarray(periods))
122+
if not np.array_equal(valid_periods, provided_periods):
123+
raise ValueError("Provided periods do not match valid periods from the sorting analyzer.")
124+
periods = valid_periods
119125

120126
return super()._set_params(
121127
metric_names=metric_names,

src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_quality_metrics_with_periods():
218218
seed=2205,
219219
)
220220

221-
# test failure when both periods and use_valid_periods are set
221+
# test failure when periods and valid_unit_periods do not match
222222
with pytest.raises(ValueError):
223223
compute_quality_metrics(
224224
sorting_analyzer,
@@ -229,6 +229,17 @@ def test_quality_metrics_with_periods():
229229
seed=2205,
230230
)
231231

232+
# should not fail if external periods are the same as valid unit periods
233+
valid_periods = sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy")
234+
metrics_ext_periods = compute_quality_metrics(
235+
sorting_analyzer,
236+
metric_names=None,
237+
skip_pc_metrics=True,
238+
use_valid_periods=True,
239+
periods=valid_periods,
240+
seed=2205,
241+
)
242+
232243
# test failure if use valid_periods is True but valid_unit_periods extension is missing
233244
sorting_analyzer.delete_extension("valid_unit_periods")
234245
with pytest.raises(AssertionError):

0 commit comments

Comments
 (0)