Skip to content

Commit 6a9966e

Browse files
authored
Improve robustness of model classifier tests (#4447)
1 parent 5d38178 commit 6a9966e

3 files changed

Lines changed: 100 additions & 66 deletions

File tree

src/spikeinterface/curation/tests/common.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
3+
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer, aggregate_units
44
from spikeinterface.core.generate import inject_some_split_units
55
from spikeinterface.curation import train_model
66
from pathlib import Path
@@ -71,6 +71,19 @@ def sorting_analyzer_for_curation():
7171
return make_sorting_analyzer(sparse=True)
7272

7373

74+
@pytest.fixture(scope="module")
75+
def sorting_analyzer_for_unitrefine_curation():
76+
"""Makes an analyzer with 20 units: the first 10 are good, normal units and the last 10 are noise. We make
77+
the noise units by using the spike trains of `sorting_2`, which are uncorrelated with the recording."""
78+
79+
recording, sorting_1 = generate_ground_truth_recording(num_channels=4, seed=1, num_units=10)
80+
_, sorting_2 = generate_ground_truth_recording(num_channels=4, seed=2, num_units=10)
81+
both_sortings = aggregate_units([sorting_1, sorting_2])
82+
analyzer = create_sorting_analyzer(sorting=both_sortings, recording=recording)
83+
analyzer.compute(["random_spikes", "noise_levels", "templates"])
84+
return analyzer
85+
86+
7487
@pytest.fixture(scope="module")
7588
def sorting_analyzer_multi_segment_for_curation():
7689
return make_sorting_analyzer(sparse=True, durations=[50.0, 30.0])
@@ -83,7 +96,7 @@ def sorting_analyzer_with_splits():
8396

8497

8598
@pytest.fixture(scope="module")
86-
def trained_pipeline_path():
99+
def trained_pipeline_path(sorting_analyzer_for_unitrefine_curation):
87100
"""
88101
Makes a model saved at "./trained_pipeline" which will be used by other tests in the module.
89102
If the model already exists, this function does nothing.
@@ -92,20 +105,22 @@ def trained_pipeline_path():
92105
if trained_model_folder.is_dir():
93106
yield trained_model_folder
94107
else:
95-
analyzer = make_sorting_analyzer(sparse=True)
108+
analyzer = sorting_analyzer_for_unitrefine_curation
96109
analyzer.compute(
97110
{
98-
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
99-
"template_metrics": {"metric_names": ["half_width"]},
111+
"quality_metrics": {"metric_names": ["snr"]},
112+
"template_metrics": {"metric_names": ["half_width", "peak_to_trough_duration", "number_of_peaks"]},
100113
}
101114
)
102115
train_model(
103-
analyzers=[analyzer] * 5,
104-
labels=[[1, 0, 1, 0, 1]] * 5,
116+
analyzers=[analyzer],
105117
folder=trained_model_folder,
106-
classifiers=["RandomForestClassifier"],
118+
labels=[[1] * 10 + [0] * 10],
107119
imputation_strategies=["median"],
108120
scaling_techniques=["standard_scaler"],
121+
classifiers=["RandomForestClassifier"],
122+
overwrite=True,
123+
search_kwargs={"cv": 3, "scoring": "balanced_accuracy", "n_iter": 2},
109124
)
110125
yield trained_model_folder
111126

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from pathlib import Path
33

4-
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
4+
from spikeinterface.curation.tests.common import sorting_analyzer_for_unitrefine_curation, trained_pipeline_path
55
from spikeinterface.curation.model_based_curation import ModelBasedClassification
66
from spikeinterface.curation import model_based_label_units, load_model
77

@@ -16,9 +16,9 @@
1616

1717
@pytest.fixture
1818
def model(trained_pipeline_path):
19-
"""A toy model, created using the `sorting_analyzer_for_curation` from `spikeinterface.curation.tests.common`.
20-
It has been trained locally and, when applied to `sorting_analyzer_for_curation` will label its 5 units with
21-
the following labels: [1,0,1,0,1]."""
19+
"""A toy model, created using the `sorting_analyzer_for_unitrefine_curation` from `spikeinterface.curation.tests.common`.
20+
It has been trained locally and, when applied to `sorting_analyzer_for_unitrefine_curation` will label its 20 units with
21+
the following labels: [1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0]."""
2222

2323
model = load_model(trained_pipeline_path, trusted=["numpy.dtype"])
2424
return model
@@ -30,32 +30,36 @@ def required_metrics_and_columns():
3030
return ["num_spikes", "snr", "half_width"], ["num_spikes", "snr", "trough_half_width", "peak_half_width"]
3131

3232

33-
def test_model_based_classification_init(sorting_analyzer_for_curation, model):
33+
def test_model_based_classification_init(sorting_analyzer_for_unitrefine_curation, model):
3434
"""Test that the ModelBasedClassification attributes are correctly initialised"""
3535

36-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
37-
assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation
36+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0])
37+
assert model_based_classification.sorting_analyzer == sorting_analyzer_for_unitrefine_curation
3838
assert model_based_classification.pipeline == model[0]
3939
assert np.all(model_based_classification.required_metrics == model_based_classification.pipeline.feature_names_in_)
4040

4141

42-
def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
42+
def test_metric_ordering_independence(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path):
4343
"""The function `model_based_label_units` needs the correct metrics to have been computed. However,
4444
it should be independent of the order of computation. We test this here."""
4545

46-
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
47-
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
46+
sorting_analyzer_for_unitrefine_curation.compute(
47+
"template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"]
48+
)
49+
sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=["snr"])
4850

4951
prediction_prob_dataframe_1 = model_based_label_units(
50-
sorting_analyzer=sorting_analyzer_for_curation,
52+
sorting_analyzer=sorting_analyzer_for_unitrefine_curation,
5153
model_folder=trained_pipeline_path,
5254
trusted=["numpy.dtype"],
5355
)
5456

55-
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
57+
sorting_analyzer_for_unitrefine_curation.compute(
58+
"template_metrics", metric_names=["peak_to_trough_duration", "half_width", "number_of_peaks"]
59+
)
5660

5761
prediction_prob_dataframe_2 = model_based_label_units(
58-
sorting_analyzer=sorting_analyzer_for_curation,
62+
sorting_analyzer=sorting_analyzer_for_unitrefine_curation,
5963
model_folder=trained_pipeline_path,
6064
trusted=["numpy.dtype"],
6165
)
@@ -64,40 +68,40 @@ def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pip
6468

6569

6670
def test_model_based_classification_get_metrics_for_classification(
67-
sorting_analyzer_for_curation, model, required_metrics_and_columns
71+
sorting_analyzer_for_unitrefine_curation, model, required_metrics_and_columns
6872
):
6973
"""If the user has not computed the required metrics, an error should be raised.
7074
This test checks that an error is raised when the required metrics have not been computed,
7175
and that no error is raised when the required metrics have been computed.
7276
"""
7377

74-
sorting_analyzer_for_curation.delete_extension("quality_metrics")
75-
sorting_analyzer_for_curation.delete_extension("template_metrics")
78+
sorting_analyzer_for_unitrefine_curation.delete_extension("quality_metrics")
79+
sorting_analyzer_for_unitrefine_curation.delete_extension("template_metrics")
7680

7781
required_metric_names, required_metric_columns = required_metrics_and_columns
7882

79-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
83+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0])
8084

8185
# Compute some (but not all) of the required metrics in sorting_analyzer, should still error
82-
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=[required_metric_names[0]])
83-
computed_metrics = sorting_analyzer_for_curation.get_metrics_extension_data()
86+
sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=[required_metric_names[0]])
87+
computed_metrics = sorting_analyzer_for_unitrefine_curation.get_metrics_extension_data()
8488
with pytest.raises(ValueError):
8589
model_based_classification._check_required_metrics_are_present(computed_metrics)
8690

8791
# Compute all of the required metrics in sorting_analyzer, no more error
88-
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=required_metric_names[0:2])
89-
sorting_analyzer_for_curation.compute("template_metrics", metric_names=[required_metric_names[2]])
92+
sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=required_metric_names[0:2])
93+
sorting_analyzer_for_unitrefine_curation.compute("template_metrics", metric_names=[required_metric_names[2]])
9094

91-
metrics_data = sorting_analyzer_for_curation.get_metrics_extension_data()
92-
assert metrics_data.shape[0] == len(sorting_analyzer_for_curation.sorting.get_unit_ids())
95+
metrics_data = sorting_analyzer_for_unitrefine_curation.get_metrics_extension_data()
96+
assert metrics_data.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.get_unit_ids())
9397
assert set(metrics_data.columns.to_list()) == set(required_metric_columns)
9498

9599

96-
def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model):
100+
def test_model_based_classification_export_to_phy(sorting_analyzer_for_unitrefine_curation, model):
97101
import pandas as pd
98102

99103
# Test the _export_to_phy() method of ModelBasedClassification
100-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
104+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0])
101105

102106
classified_units = pd.DataFrame.from_dict({0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)})
103107
# Function should fail here
@@ -112,42 +116,48 @@ def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation,
112116
assert (phy_folder / "cluster_prediction.tsv").exists()
113117

114118

115-
def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, model):
119+
def test_model_based_classification_predict_labels(sorting_analyzer_for_unitrefine_curation, model):
116120
"""The model `model` has been trained on the `sorting_analyzer` used in this test with
117121
the labels `[1] * 10 + [0] * 10`. Hence if we apply the model to this `sorting_analyzer`
118122
we expect these labels to be outputted. The test checks this, and also checks
119123
that label conversion works as expected."""
120124

121-
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
122-
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
125+
sorting_analyzer_for_unitrefine_curation.compute(
126+
"template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"]
127+
)
128+
sorting_analyzer_for_unitrefine_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
123129

124130
# Test the predict_labels() method of ModelBasedClassification
125-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0])
131+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model[0])
126132
classified_units = model_based_classification.predict_labels()
127133
predictions = classified_units["prediction"].values
128134

129-
assert np.all(predictions == np.array([1, 0, 1, 0, 1]))
135+
expected_result = np.array([1] * 10 + [0] * 10)
136+
assert np.all(predictions == expected_result)
130137

131138
conversion = {0: "noise", 1: "good"}
139+
expected_result_converted = np.array(["good"] * 10 + ["noise"] * 10)
132140
classified_units_labelled = model_based_classification.predict_labels(label_conversion=conversion)
133141
predictions_labelled = classified_units_labelled["prediction"]
134-
assert np.all(predictions_labelled == ["good", "noise", "good", "noise", "good"])
142+
assert np.all(predictions_labelled == expected_result_converted)
135143

136144

137145
@pytest.mark.skip(reason="We need to retrain the model to reflect any changes in metric computation")
138-
def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_curation, trained_pipeline_path):
146+
def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path):
139147
"""We track whether the metric parameters used to compute the metrics used to train
140148
a model are the same as the parameters used to compute the metrics in the sorting
141149
analyzer which is being curated. If they are different, an error or warning will
142150
be raised depending on the `enforce_metric_params` kwarg. This behaviour is tested here."""
143151

144-
sorting_analyzer_for_curation.compute(
145-
"quality_metrics", metric_names=["num_spikes", "snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}}
152+
sorting_analyzer_for_unitrefine_curation.compute(
153+
"quality_metrics", metric_names=["snr"], metric_params={"snr": {"peak_mode": "peak_to_peak"}}
154+
)
155+
sorting_analyzer_for_unitrefine_curation.compute(
156+
"template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"]
146157
)
147-
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
148158

149159
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
150-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
160+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model)
151161

152162
# an error should be raised if `enforce_metric_params` is True
153163
with pytest.raises(Exception):
@@ -158,13 +168,15 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
158168
model_based_classification._check_params_for_classification(enforce_metric_params=False, model_info=model_info)
159169

160170
# Now test the positive case. Recompute using the default parameters
161-
sorting_analyzer_for_curation.compute(
171+
sorting_analyzer_for_unitrefine_curation.compute(
162172
"quality_metrics",
163-
metric_names=["num_spikes", "snr"],
173+
metric_names=["snr"],
164174
metric_params={"snr": {"peak_sign": "neg", "peak_mode": "extremum"}},
165175
)
166-
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
176+
sorting_analyzer_for_unitrefine_curation.compute(
177+
"template_metrics", metric_names=["half_width", "peak_to_trough_duration"]
178+
)
167179

168180
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
169-
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
181+
model_based_classification = ModelBasedClassification(sorting_analyzer_for_unitrefine_curation, model)
170182
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,91 @@
11
import pytest
22

3-
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
3+
from spikeinterface.curation.tests.common import sorting_analyzer_for_unitrefine_curation, trained_pipeline_path
44
from spikeinterface.curation import unitrefine_label_units
55

66

7-
def test_unitrefine_label_units_hf(sorting_analyzer_for_curation):
7+
def test_unitrefine_label_units_hf(sorting_analyzer_for_unitrefine_curation):
88
"""Test the `unitrefine_label_units` function."""
9-
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
10-
sorting_analyzer_for_curation.compute("quality_metrics")
9+
sorting_analyzer_for_unitrefine_curation.compute(
10+
{
11+
"spike_amplitudes": {},
12+
"template_metrics": {"include_multi_channel_metrics": True},
13+
"quality_metrics": {},
14+
}
15+
)
1116

1217
# test passing both classifiers
1318
labels = unitrefine_label_units(
14-
sorting_analyzer_for_curation,
19+
sorting_analyzer_for_unitrefine_curation,
1520
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
1621
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
1722
)
1823

1924
assert "unitrefine_label" in labels.columns
2025
assert "unitrefine_probability" in labels.columns
21-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
26+
assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids)
2227

2328
# test only noise neural classifier
2429
labels = unitrefine_label_units(
25-
sorting_analyzer_for_curation,
30+
sorting_analyzer_for_unitrefine_curation,
2631
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
2732
sua_mua_classifier=None,
2833
)
2934

3035
assert "unitrefine_label" in labels.columns
3136
assert "unitrefine_probability" in labels.columns
32-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
37+
assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids)
3338

3439
# test only sua mua classifier
3540
labels = unitrefine_label_units(
36-
sorting_analyzer_for_curation,
41+
sorting_analyzer_for_unitrefine_curation,
3742
noise_neural_classifier=None,
3843
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
3944
)
4045

4146
assert "unitrefine_label" in labels.columns
4247
assert "unitrefine_probability" in labels.columns
43-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
48+
assert labels.shape[0] == len(sorting_analyzer_for_unitrefine_curation.sorting.unit_ids)
4449

4550
# test passing none
4651
with pytest.raises(ValueError):
4752
labels = unitrefine_label_units(
48-
sorting_analyzer_for_curation,
53+
sorting_analyzer_for_unitrefine_curation,
4954
noise_neural_classifier=None,
5055
sua_mua_classifier=None,
5156
)
5257

5358
# test warnings when unexpected labels are returned
5459
with pytest.warns(UserWarning):
5560
labels = unitrefine_label_units(
56-
sorting_analyzer_for_curation,
61+
sorting_analyzer_for_unitrefine_curation,
5762
noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
5863
sua_mua_classifier=None,
5964
)
6065

6166
with pytest.warns(UserWarning):
6267
labels = unitrefine_label_units(
63-
sorting_analyzer_for_curation,
68+
sorting_analyzer_for_unitrefine_curation,
6469
noise_neural_classifier=None,
6570
sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
6671
)
6772

6873

69-
def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path):
74+
def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_unitrefine_curation, trained_pipeline_path):
7075
# test with trained local models
71-
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
72-
sorting_analyzer_for_curation.compute("quality_metrics")
76+
sorting_analyzer_for_unitrefine_curation.compute(
77+
"template_metrics", metric_names=["half_width", "peak_to_trough_duration", "number_of_peaks"]
78+
)
79+
sorting_analyzer_for_unitrefine_curation.compute("quality_metrics")
7380

7481
# test passing model folder
7582
labels = unitrefine_label_units(
76-
sorting_analyzer_for_curation,
83+
sorting_analyzer_for_unitrefine_curation,
7784
noise_neural_classifier=trained_pipeline_path,
7885
)
7986

8087
# test passing the model file path directly
8188
labels = unitrefine_label_units(
82-
sorting_analyzer_for_curation,
89+
sorting_analyzer_for_unitrefine_curation,
8390
noise_neural_classifier=trained_pipeline_path / "best_model.skops",
8491
)

0 commit comments

Comments
 (0)