11import pytest
22from pathlib import Path
33
4- from spikeinterface .curation .tests .common import sorting_analyzer_for_curation , trained_pipeline_path
4+ from spikeinterface .curation .tests .common import sorting_analyzer_for_unitrefine_curation , trained_pipeline_path
55from spikeinterface .curation .model_based_curation import ModelBasedClassification
66from spikeinterface .curation import model_based_label_units , load_model
77
1616
1717@pytest .fixture
1818def model (trained_pipeline_path ):
19- """A toy model, created using the `sorting_analyzer_for_curation ` from `spikeinterface.curation.tests.common`.
20- It has been trained locally and, when applied to `sorting_analyzer_for_curation ` will label its 5 units with
21- the following labels: [1,0 ,1,0,1 ]."""
19+ """A toy model, created using the `sorting_analyzer_for_unitrefine_curation ` from `spikeinterface.curation.tests.common`.
20+ It has been trained locally and, when applied to `sorting_analyzer_for_unitrefine_curation ` will label its 10 units with
21+ the following labels: [1,1 ,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0 ]."""
2222
2323 model = load_model (trained_pipeline_path , trusted = ["numpy.dtype" ])
2424 return model
@@ -30,32 +30,36 @@ def required_metrics_and_columns():
3030 return ["num_spikes" , "snr" , "half_width" ], ["num_spikes" , "snr" , "trough_half_width" , "peak_half_width" ]
3131
3232
33- def test_model_based_classification_init (sorting_analyzer_for_curation , model ):
33+ def test_model_based_classification_init (sorting_analyzer_for_unitrefine_curation , model ):
3434 """Test that the ModelBasedClassification attributes are correctly initialised"""
3535
36- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model [0 ])
37- assert model_based_classification .sorting_analyzer == sorting_analyzer_for_curation
36+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model [0 ])
37+ assert model_based_classification .sorting_analyzer == sorting_analyzer_for_unitrefine_curation
3838 assert model_based_classification .pipeline == model [0 ]
3939 assert np .all (model_based_classification .required_metrics == model_based_classification .pipeline .feature_names_in_ )
4040
4141
42- def test_metric_ordering_independence (sorting_analyzer_for_curation , trained_pipeline_path ):
42+ def test_metric_ordering_independence (sorting_analyzer_for_unitrefine_curation , trained_pipeline_path ):
4343 """The function `model_based_label_units` needs the correct metrics to have been computed. However,
4444 it should be independent of the order of computation. We test this here."""
4545
46- sorting_analyzer_for_curation .compute ("template_metrics" , metric_names = ["half_width" ])
47- sorting_analyzer_for_curation .compute ("quality_metrics" , metric_names = ["num_spikes" , "snr" ])
46+ sorting_analyzer_for_unitrefine_curation .compute (
47+ "template_metrics" , metric_names = ["half_width" , "peak_to_trough_duration" , "number_of_peaks" ]
48+ )
49+ sorting_analyzer_for_unitrefine_curation .compute ("quality_metrics" , metric_names = ["snr" ])
4850
4951 prediction_prob_dataframe_1 = model_based_label_units (
50- sorting_analyzer = sorting_analyzer_for_curation ,
52+ sorting_analyzer = sorting_analyzer_for_unitrefine_curation ,
5153 model_folder = trained_pipeline_path ,
5254 trusted = ["numpy.dtype" ],
5355 )
5456
55- sorting_analyzer_for_curation .compute ("quality_metrics" , metric_names = ["snr" , "num_spikes" ])
57+ sorting_analyzer_for_unitrefine_curation .compute (
58+ "template_metrics" , metric_names = ["peak_to_trough_duration" , "half_width" , "number_of_peaks" ]
59+ )
5660
5761 prediction_prob_dataframe_2 = model_based_label_units (
58- sorting_analyzer = sorting_analyzer_for_curation ,
62+ sorting_analyzer = sorting_analyzer_for_unitrefine_curation ,
5963 model_folder = trained_pipeline_path ,
6064 trusted = ["numpy.dtype" ],
6165 )
@@ -64,40 +68,40 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip
6468
6569
6670def test_model_based_classification_get_metrics_for_classification (
67- sorting_analyzer_for_curation , model , required_metrics_and_columns
71+ sorting_analyzer_for_unitrefine_curation , model , required_metrics_and_columns
6872):
6973 """If the user has not computed the required metrics, an error should be returned.
7074 This test checks that an error occurs when the required metrics have not been computed,
7175 and that no error is returned when the required metrics have been computed.
7276 """
7377
74- sorting_analyzer_for_curation .delete_extension ("quality_metrics" )
75- sorting_analyzer_for_curation .delete_extension ("template_metrics" )
78+ sorting_analyzer_for_unitrefine_curation .delete_extension ("quality_metrics" )
79+ sorting_analyzer_for_unitrefine_curation .delete_extension ("template_metrics" )
7680
7781 required_metric_names , required_metric_columns = required_metrics_and_columns
7882
79- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model [0 ])
83+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model [0 ])
8084
8185 # Compute some (but not all) of the required metrics in sorting_analyzer, should still error
82- sorting_analyzer_for_curation .compute ("quality_metrics" , metric_names = [required_metric_names [0 ]])
83- computed_metrics = sorting_analyzer_for_curation .get_metrics_extension_data ()
86+ sorting_analyzer_for_unitrefine_curation .compute ("quality_metrics" , metric_names = [required_metric_names [0 ]])
87+ computed_metrics = sorting_analyzer_for_unitrefine_curation .get_metrics_extension_data ()
8488 with pytest .raises (ValueError ):
8589 model_based_classification ._check_required_metrics_are_present (computed_metrics )
8690
8791 # Compute all of the required metrics in sorting_analyzer, no more error
88- sorting_analyzer_for_curation .compute ("quality_metrics" , metric_names = required_metric_names [0 :2 ])
89- sorting_analyzer_for_curation .compute ("template_metrics" , metric_names = [required_metric_names [2 ]])
92+ sorting_analyzer_for_unitrefine_curation .compute ("quality_metrics" , metric_names = required_metric_names [0 :2 ])
93+ sorting_analyzer_for_unitrefine_curation .compute ("template_metrics" , metric_names = [required_metric_names [2 ]])
9094
91- metrics_data = sorting_analyzer_for_curation .get_metrics_extension_data ()
92- assert metrics_data .shape [0 ] == len (sorting_analyzer_for_curation .sorting .get_unit_ids ())
95+ metrics_data = sorting_analyzer_for_unitrefine_curation .get_metrics_extension_data ()
96+ assert metrics_data .shape [0 ] == len (sorting_analyzer_for_unitrefine_curation .sorting .get_unit_ids ())
9397 assert set (metrics_data .columns .to_list ()) == set (required_metric_columns )
9498
9599
96- def test_model_based_classification_export_to_phy (sorting_analyzer_for_curation , model ):
100+ def test_model_based_classification_export_to_phy (sorting_analyzer_for_unitrefine_curation , model ):
97101 import pandas as pd
98102
99103 # Test the _export_to_phy() method of ModelBasedClassification
100- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model [0 ])
104+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model [0 ])
101105
102106 classified_units = pd .DataFrame .from_dict ({0 : (1 , 0.5 ), 1 : (0 , 0.5 ), 2 : (1 , 0.5 ), 3 : (0 , 0.5 ), 4 : (1 , 0.5 )})
103107 # Function should fail here
@@ -112,42 +116,48 @@ def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation,
112116 assert (phy_folder / "cluster_prediction.tsv" ).exists ()
113117
114118
115- def test_model_based_classification_predict_labels (sorting_analyzer_for_curation , model ):
119+ def test_model_based_classification_predict_labels (sorting_analyzer_for_unitrefine_curation , model ):
116120 """The model `model` has been trained on the `sorting_analyzer` used in this test with
117121 the labels `[1, 0, 1, 0, 1]`. Hence if we apply the model to this `sorting_analyzer`
118122 we expect these labels to be outputted. The test checks this, and also checks
119123 that label conversion works as expected."""
120124
121- sorting_analyzer_for_curation .compute ("template_metrics" , metric_names = ["half_width" ])
122- sorting_analyzer_for_curation .compute ("quality_metrics" , metric_names = ["num_spikes" , "snr" ])
125+ sorting_analyzer_for_unitrefine_curation .compute (
126+ "template_metrics" , metric_names = ["half_width" , "peak_to_trough_duration" , "number_of_peaks" ]
127+ )
128+ sorting_analyzer_for_unitrefine_curation .compute ("quality_metrics" , metric_names = ["num_spikes" , "snr" ])
123129
124130 # Test the predict_labels() method of ModelBasedClassification
125- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model [0 ])
131+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model [0 ])
126132 classified_units = model_based_classification .predict_labels ()
127133 predictions = classified_units ["prediction" ].values
128134
129- assert np .all (predictions == np .array ([1 , 0 , 1 , 0 , 1 ]))
135+ expected_result = np .array ([1 ] * 10 + [0 ] * 10 )
136+ assert np .all (predictions == expected_result )
130137
131138 conversion = {0 : "noise" , 1 : "good" }
139+ expected_result_converted = np .array (["good" ] * 10 + ["noise" ] * 10 )
132140 classified_units_labelled = model_based_classification .predict_labels (label_conversion = conversion )
133141 predictions_labelled = classified_units_labelled ["prediction" ]
134- assert np .all (predictions_labelled == [ "good" , "noise" , "good" , "noise" , "good" ] )
142+ assert np .all (predictions_labelled == expected_result_converted )
135143
136144
137145@pytest .mark .skip (reason = "We need to retrain the model to reflect any changes in metric computation" )
138- def test_exception_raised_when_metric_params_not_equal (sorting_analyzer_for_curation , trained_pipeline_path ):
146+ def test_exception_raised_when_metric_params_not_equal (sorting_analyzer_for_unitrefine_curation , trained_pipeline_path ):
139147 """We track whether the metric parameters used to compute the metrics used to train
140148 a model are the same as the parameters used to compute the metrics in the sorting
141149 analyzer which is being curated. If they are different, an error or warning will
142150 be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here."""
143151
144- sorting_analyzer_for_curation .compute (
145- "quality_metrics" , metric_names = ["num_spikes" , "snr" ], metric_params = {"snr" : {"peak_mode" : "peak_to_peak" }}
152+ sorting_analyzer_for_unitrefine_curation .compute (
153+ "quality_metrics" , metric_names = ["snr" ], metric_params = {"snr" : {"peak_mode" : "peak_to_peak" }}
154+ )
155+ sorting_analyzer_for_unitrefine_curation .compute (
156+ "template_metrics" , metric_names = ["half_width" , "peak_to_trough_duration" , "number_of_peaks" ]
146157 )
147- sorting_analyzer_for_curation .compute ("template_metrics" , metric_names = ["half_width" ])
148158
149159 model , model_info = load_model (model_folder = trained_pipeline_path , trusted = ["numpy.dtype" ])
150- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model )
160+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model )
151161
152162 # an error should be raised if `enforce_metric_params` is True
153163 with pytest .raises (Exception ):
@@ -158,13 +168,15 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
158168 model_based_classification ._check_params_for_classification (enforce_metric_params = False , model_info = model_info )
159169
160170 # Now test the positive case. Recompute using the default parameters
161- sorting_analyzer_for_curation .compute (
171+ sorting_analyzer_for_unitrefine_curation .compute (
162172 "quality_metrics" ,
163- metric_names = ["num_spikes" , " snr" ],
173+ metric_names = ["snr" ],
164174 metric_params = {"snr" : {"peak_sign" : "neg" , "peak_mode" : "extremum" }},
165175 )
166- sorting_analyzer_for_curation .compute ("template_metrics" , metric_names = ["half_width" ])
176+ sorting_analyzer_for_unitrefine_curation .compute (
177+ "template_metrics" , metric_names = ["half_width" , "peak_to_trough_duration" ]
178+ )
167179
168180 model , model_info = load_model (model_folder = trained_pipeline_path , trusted = ["numpy.dtype" ])
169- model_based_classification = ModelBasedClassification (sorting_analyzer_for_curation , model )
181+ model_based_classification = ModelBasedClassification (sorting_analyzer_for_unitrefine_curation , model )
170182 model_based_classification ._check_params_for_classification (enforce_metric_params = True , model_info = model_info )
0 commit comments