diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 399908ec03..a72e3d3775 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -6,6 +6,7 @@ from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension +from spikeinterface.core.sorting_tools import cast_periods_to_unit_period_dtype from .misc_metrics import misc_metrics_list from .pca_metrics import pca_metrics_list @@ -115,15 +116,19 @@ def _set_params( metric_names = [m for m in metric_names if m not in pc_metric_names] if use_valid_periods: + valid_periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") if periods is not None: - raise ValueError("If use_valid_periods is True, periods should not be provided.") - periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") + provided_periods = cast_periods_to_unit_period_dtype(np.asarray(periods)) + if not np.array_equal(valid_periods, provided_periods): + raise ValueError("Provided periods do not match valid periods from the sorting analyzer.") + periods = valid_periods return super()._set_params( metric_names=metric_names, metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + use_valid_periods=use_valid_periods, periods=periods, peak_sign=peak_sign, seed=seed, diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index dfd47c4df9..8566217d0b 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -218,7 +218,7 @@ def test_quality_metrics_with_periods(): seed=2205, ) - # test failure when both periods and use_valid_periods are set + # test failure when periods and valid_unit_periods do not match with pytest.raises(ValueError): compute_quality_metrics( sorting_analyzer, @@ -229,6 +229,17 @@ def test_quality_metrics_with_periods(): seed=2205, ) + # should not fail if external periods are the same as valid unit periods + valid_periods = sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") + metrics_ext_periods = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + periods=valid_periods, + seed=2205, + ) + # test failure if use valid_periods is True but valid_unit_periods extension is missing sorting_analyzer.delete_extension("valid_unit_periods") with pytest.raises(AssertionError):