File tree Expand file tree Collapse file tree
src/spikeinterface/metrics/quality Expand file tree Collapse file tree Original file line number Diff line number Diff line change 66from spikeinterface .core .template_tools import get_template_extremum_channel
77from spikeinterface .core .sortinganalyzer import register_result_extension
88from spikeinterface .core .analyzer_extension_core import BaseMetricExtension
9+ from spikeinterface .core .sorting_tools import cast_periods_to_unit_period_dtype
910
1011from .misc_metrics import misc_metrics_list
1112from .pca_metrics import pca_metrics_list
@@ -115,7 +116,12 @@ def _set_params(
115116 metric_names = [m for m in metric_names if m not in pc_metric_names ]
116117
117118 if use_valid_periods :
118- periods = self .sorting_analyzer .get_extension ("valid_unit_periods" ).get_data (outputs = "numpy" )
119+ valid_periods = self .sorting_analyzer .get_extension ("valid_unit_periods" ).get_data (outputs = "numpy" )
120+ if periods is not None :
121+ provided_periods = cast_periods_to_unit_period_dtype (np .asarray (periods ))
122+ if not np .array_equal (valid_periods , provided_periods ):
123+ raise ValueError ("Provided periods do not match valid periods from the sorting analyzer." )
124+ periods = valid_periods
119125
120126 return super ()._set_params (
121127 metric_names = metric_names ,
Original file line number Diff line number Diff line change @@ -218,7 +218,7 @@ def test_quality_metrics_with_periods():
218218 seed = 2205 ,
219219 )
220220
221- # test failure when both periods and use_valid_periods are set
221+ # test failure when periods and valid_unit_periods do not match
222222 with pytest .raises (ValueError ):
223223 compute_quality_metrics (
224224 sorting_analyzer ,
@@ -229,6 +229,17 @@ def test_quality_metrics_with_periods():
229229 seed = 2205 ,
230230 )
231231
232+ # should not fail if external periods are the same as valid unit periods
233+ valid_periods = sorting_analyzer .get_extension ("valid_unit_periods" ).get_data (outputs = "numpy" )
234+ metrics_ext_periods = compute_quality_metrics (
235+ sorting_analyzer ,
236+ metric_names = None ,
237+ skip_pc_metrics = True ,
238+ use_valid_periods = True ,
239+ periods = valid_periods ,
240+ seed = 2205 ,
241+ )
242+
232243 # test failure if use valid_periods is True but valid_unit_periods extension is missing
233244 sorting_analyzer .delete_extension ("valid_unit_periods" )
234245 with pytest .raises (AssertionError ):
You can’t perform that action at this time.
0 commit comments