Skip to content

Commit 924020e

Browse files
ecobostpre-commit-ci[bot]alejoe91
authored
Change 'min'/'max' to 'greater'/'less' when defining thresholds for threshold_metrics_label_units (#4416)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 4ce1622 commit 924020e

File tree

10 files changed

+113
-102
lines changed

10 files changed

+113
-102
lines changed

doc/how_to/auto_label_units.rst

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ curation:
7979
8080
8181
1. Quality-metrics based curation
82-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
82+
----------------------------------
8383

8484
A simple solution is to use a filter based on quality metrics. To do so,
85-
we can use the ``spikeinterface.curation.qualitymetrics_label_units``
85+
we can use the ``spikeinterface.curation.threshold_metrics_label_units``
8686
function and provide a set of thresholds.
8787

8888
.. code:: ipython3
8989
9090
qm_thresholds = {
91-
"snr": {"min": 5},
92-
"firing_rate": {"min": 0.1, "max": 200},
93-
"rp_contamination": {"max": 0.5}
91+
"snr": {"greater": 5},
92+
"firing_rate": {"greater": 0.1, "less": 200},
93+
"rp_contamination": {"less": 0.5}
9494
}
9595
9696
.. code:: ipython3
@@ -143,7 +143,7 @@ across all units:
143143
.. image:: auto_label_units_files/auto_label_units_14_0.png
144144

145145

146-
1. Bombcell
146+
2. Bombcell
147147
-----------
148148

149149
**Bombcell** ([Fabre]_) is another threshold-based method that also uses
@@ -161,24 +161,24 @@ file.
161161
162162
.. parsed-literal::
163163
164-
{'mua': {'amplitude_cutoff': {'max': 0.2, 'min': None},
165-
'amplitude_median': {'max': None, 'min': 40},
166-
'drift_ptp': {'max': 100, 'min': None},
167-
'num_spikes': {'max': None, 'min': 300},
168-
'presence_ratio': {'max': None, 'min': 0.7},
169-
'rp_contamination': {'max': 0.1, 'min': None},
170-
'snr': {'max': None, 'min': 5}},
171-
'noise': {'exp_decay': {'max': 0.1, 'min': 0.01},
172-
'num_negative_peaks': {'max': 1, 'min': None},
173-
'num_positive_peaks': {'max': 2, 'min': None},
174-
'peak_after_to_trough_ratio': {'max': 0.8, 'min': None},
175-
'peak_to_trough_duration': {'max': 0.00115, 'min': 0.0001},
176-
'waveform_baseline_flatness': {'max': 0.5, 'min': None}},
177-
'non-somatic': {'main_peak_to_trough_ratio': {'max': 0.8, 'min': None},
178-
'peak_before_to_peak_after_ratio': {'max': 3, 'min': None},
179-
'peak_before_to_trough_ratio': {'max': 3, 'min': None},
180-
'peak_before_width': {'max': None, 'min': 0.00015},
181-
'trough_width': {'max': None, 'min': 0.0002}}}
164+
{'mua': {'amplitude_cutoff': {'greater': None, 'less': 0.2},
165+
'amplitude_median': {'abs': True, 'greater': 30, 'less': None},
166+
'drift_ptp': {'greater': None, 'less': 100},
167+
'num_spikes': {'greater': 300, 'less': None},
168+
'presence_ratio': {'greater': 0.7, 'less': None},
169+
'rp_contamination': {'greater': None, 'less': 0.1},
170+
'snr': {'greater': 5, 'less': None}},
171+
'noise': {'exp_decay': {'greater': 0.01, 'less': 0.1},
172+
'num_negative_peaks': {'greater': None, 'less': 1},
173+
'num_positive_peaks': {'greater': None, 'less': 2},
174+
'peak_after_to_trough_ratio': {'greater': None, 'less': 0.8},
175+
'peak_to_trough_duration': {'greater': 0.0001, 'less': 0.00115},
176+
'waveform_baseline_flatness': {'greater': None, 'less': 0.5}},
177+
'non-somatic': {'main_peak_to_trough_ratio': {'greater': None, 'less': 0.8},
178+
'peak_before_to_peak_after_ratio': {'greater': None, 'less': 3},
179+
'peak_before_to_trough_ratio': {'greater': None, 'less': 3},
180+
'peak_before_width': {'greater': 0.00015, 'less': None},
181+
'trough_width': {'greater': 0.0002, 'less': None}}}
182182
183183
184184
.. code:: ipython3
@@ -248,8 +248,8 @@ contamination (``rp_contamination``).
248248
.. image:: auto_label_units_files/auto_label_units_23_1.png
249249

250250

251-
UnitRefine
252-
----------
251+
3. UnitRefine
252+
-------------
253253

254254
**UnitRefine** ([Jain]_) also uses quality and template metrics, but in
255255
a different way. It uses pre-trained classifiers trained on
@@ -305,12 +305,11 @@ sorting with different strategies. We recommend running **Bombcell** and
305305
**UnitRefine** as part of your pipeline. These methods will facilitate
306306
further curation and make downstream analysis cleaner.
307307

308-
To remove units from your ``SortingAnalyzer``, you can simply use the
309-
``select_units`` function:
310-
311308
Remove units from ``SortingAnalyzer``
312309
-------------------------------------
313310

311+
To remove units from your ``SortingAnalyzer``, you can use the ``select_units`` function.
312+
314313
After auto-labeling, we can easily remove the “noise” units for
315314
downstream analysis:
316315

doc/modules/curation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ which applies a set of thresholds based on the available metrics (template/quali
132132
labels = threshold_metrics_label_units(
133133
sorting_analyzer=sorting_analyzer,
134134
thresholds={
135-
"snr": {"min": 5},
136-
"rp_contamination": {"max": 0.2},
135+
"snr": {"greater": 5},
136+
"rp_contamination": {"less": 0.2},
137137
},
138138
pass_label="good",
139139
fail_label="bad",

examples/how_to/auto_label_units.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@
5757

5858
# %%
5959
qm_thresholds = {
60-
"snr": {"min": 5},
61-
"firing_rate": {"min": 0.1, "max": 200},
62-
"rp_contamination": {"max": 0.5}
60+
"snr": {"greater": 5},
61+
"firing_rate": {"greater": 0.1, "less": 200},
62+
"rp_contamination": {"less": 0.5}
6363
}
6464

6565
# %%

src/spikeinterface/curation/bombcell_curation.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,34 +50,34 @@ def bombcell_get_default_thresholds() -> dict:
5050
"""
5151
bombcell - Returns default thresholds for unit labeling.
5252
53-
Each metric has 'min' and 'max' values. Use None to disable a threshold (e.g. to ignore a metric completely
54-
or to only have a min or a max threshold)
53+
Each metric has 'greater' and 'less' values. Use None to disable a threshold (e.g. to ignore a metric completely
54+
or to only have a 'greater' or a 'less' threshold)
5555
"""
5656
# bombcell
5757
return {
5858
"noise": { # failures -> NOISE
59-
"num_positive_peaks": {"min": None, "max": 2},
60-
"num_negative_peaks": {"min": None, "max": 1},
61-
"peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds
62-
"waveform_baseline_flatness": {"min": None, "max": 0.5},
63-
"peak_after_to_trough_ratio": {"min": None, "max": 0.8},
64-
"exp_decay": {"min": 0.01, "max": 0.1},
59+
"num_positive_peaks": {"greater": None, "less": 2},
60+
"num_negative_peaks": {"greater": None, "less": 1},
61+
"peak_to_trough_duration": {"greater": 0.0001, "less": 0.00115}, # seconds
62+
"waveform_baseline_flatness": {"greater": None, "less": 0.5},
63+
"peak_after_to_trough_ratio": {"greater": None, "less": 0.8},
64+
"exp_decay": {"greater": 0.01, "less": 0.1},
6565
},
6666
"mua": { # failures -> MUA, only applied to units that passed noise thresholds
67-
"amplitude_median": {"min": 30, "max": None, "abs": True}, # uV
68-
"snr": {"min": 5, "max": None},
69-
"amplitude_cutoff": {"min": None, "max": 0.2},
70-
"num_spikes": {"min": 300, "max": None},
71-
"rp_contamination": {"min": None, "max": 0.1},
72-
"presence_ratio": {"min": 0.7, "max": None},
73-
"drift_ptp": {"min": None, "max": 100}, # um
67+
"amplitude_median": {"greater": 30, "less": None, "abs": True}, # uV
68+
"snr": {"greater": 5, "less": None},
69+
"amplitude_cutoff": {"greater": None, "less": 0.2},
70+
"num_spikes": {"greater": 300, "less": None},
71+
"rp_contamination": {"greater": None, "less": 0.1},
72+
"presence_ratio": {"greater": 0.7, "less": None},
73+
"drift_ptp": {"greater": None, "less": 100}, # um
7474
},
7575
"non-somatic": {
76-
"peak_before_to_trough_ratio": {"min": None, "max": 3},
77-
"peak_before_width": {"min": 0.00015, "max": None}, # seconds
78-
"trough_width": {"min": 0.0002, "max": None}, # seconds
79-
"peak_before_to_peak_after_ratio": {"min": None, "max": 3},
80-
"main_peak_to_trough_ratio": {"min": None, "max": 0.8},
76+
"peak_before_to_trough_ratio": {"greater": None, "less": 3},
77+
"peak_before_width": {"greater": 0.00015, "less": None}, # seconds
78+
"trough_width": {"greater": 0.0002, "less": None}, # seconds
79+
"peak_before_to_peak_after_ratio": {"greater": None, "less": 3},
80+
"main_peak_to_trough_ratio": {"greater": None, "less": 0.8},
8181
},
8282
}
8383

@@ -123,7 +123,7 @@ def bombcell_label_units(
123123
If provided, metrics are extracted automatically using get_metrics_extension_data().
124124
thresholds : dict | str | Path | None
125125
Threshold dict or JSON file, including three sections ("noise", "mua", "non-somatic") of
126-
{"metric": {"min": val, "max": val}}.
126+
{"metric": {"greater": val, "less": val}}.
127127
If None, default Bombcell thresholds are used.
128128
label_non_somatic : bool, default: True
129129
If True, detect non-somatic (dendritic, axonal) units.
@@ -336,8 +336,8 @@ def save_bombcell_results(
336336
continue
337337
value = metrics.loc[unit_id, metric_name]
338338
thresh = flat_thresholds[metric_name]
339-
thresh_min = thresh.get("min", None)
340-
thresh_max = thresh.get("max", None)
339+
thresh_min = thresh.get("greater", None)
340+
thresh_max = thresh.get("less", None)
341341

342342
# Determine pass/fail
343343
passed = True

src/spikeinterface/curation/tests/test_bombcell_curation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def test_bombcell_label_units_with_threshold_file(sorting_analyzer_with_metrics,
5555

5656
# Define custom thresholds
5757
custom_thresholds = {
58-
"snr": {"min": 5, "max": 100},
59-
"isi_violations": {"min": None, "max": 0.2},
58+
"snr": {"greater": 5, "less": 100},
59+
"isi_violations": {"greater": None, "less": 0.2},
6060
}
6161

6262
# Save thresholds to a temporary JSON file

src/spikeinterface/curation/tests/test_threshold_metrics_curation.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def test_threshold_metrics_label_units_with_dataframe():
1717
index=[0, 1, 2],
1818
)
1919
thresholds = {
20-
"snr": {"min": 5.0},
21-
"firing_rate": {"min": 0.1, "max": 20.0},
20+
"snr": {"greater": 5.0},
21+
"firing_rate": {"greater": 0.1, "less": 20.0},
2222
}
2323

2424
labels = threshold_metrics_label_units(metrics, thresholds)
@@ -39,8 +39,8 @@ def test_threshold_metrics_label_units_with_file(tmp_path):
3939
index=[0, 1],
4040
)
4141
thresholds = {
42-
"snr": {"min": 5.0},
43-
"firing_rate": {"min": 0.1},
42+
"snr": {"greater": 5.0},
43+
"firing_rate": {"greater": 0.1},
4444
}
4545

4646
thresholds_file = tmp_path / "thresholds.json"
@@ -63,8 +63,8 @@ def test_threshold_metrics_label_external_labels():
6363
index=[0, 1],
6464
)
6565
thresholds = {
66-
"snr": {"min": 5.0},
67-
"firing_rate": {"min": 0.1},
66+
"snr": {"greater": 5.0},
67+
"firing_rate": {"greater": 0.1},
6868
}
6969

7070
labels = threshold_metrics_label_units(
@@ -86,7 +86,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe():
8686
},
8787
index=[0, 1, 2, 3],
8888
)
89-
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
89+
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}
9090

9191
labels_and = threshold_metrics_label_units(
9292
metrics,
@@ -115,7 +115,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and():
115115
},
116116
index=[10, 11, 12],
117117
)
118-
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
118+
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}
119119

120120
labels_fail = threshold_metrics_label_units(
121121
metrics,
@@ -147,7 +147,7 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or():
147147
},
148148
index=[20, 21],
149149
)
150-
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
150+
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}
151151

152152
labels_ignore_or = threshold_metrics_label_units(
153153
metrics,
@@ -170,7 +170,7 @@ def test_threshold_metrics_label_units_nan_policy_pass_and_or():
170170
},
171171
index=[30, 31, 32, 33],
172172
)
173-
thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}}
173+
thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}}
174174

175175
labels_and = threshold_metrics_label_units(
176176
metrics,
@@ -198,7 +198,7 @@ def test_threshold_metrics_label_units_invalid_operator_raises():
198198
import pandas as pd
199199

200200
metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
201-
thresholds = {"m1": {"min": 0.0}}
201+
thresholds = {"m1": {"greater": 0.0}}
202202
with pytest.raises(ValueError, match="operator must be 'and' or 'or'"):
203203
threshold_metrics_label_units(metrics, thresholds, operator="xor")
204204

@@ -207,7 +207,7 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises():
207207
import pandas as pd
208208

209209
metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
210-
thresholds = {"m1": {"min": 0.0}}
210+
thresholds = {"m1": {"greater": 0.0}}
211211
with pytest.raises(ValueError, match="nan_policy must be"):
212212
threshold_metrics_label_units(metrics, thresholds, nan_policy="omit")
213213

@@ -216,6 +216,15 @@ def test_threshold_metrics_label_units_missing_metric_raises():
216216
import pandas as pd
217217

218218
metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
219-
thresholds = {"does_not_exist": {"min": 0.0}}
219+
thresholds = {"does_not_exist": {"greater": 0.0}}
220220
with pytest.raises(ValueError, match="specified in thresholds are not present"):
221221
threshold_metrics_label_units(metrics, thresholds)
222+
223+
224+
def test_threshold_metrics_label_units_invalid_threshold_keys_raises():
225+
import pandas as pd
226+
227+
metrics = pd.DataFrame({"m1": [1.0]}, index=[0])
228+
thresholds = {"m1": {"greater": 0.0, "invalid_key": 1.0}}
229+
with pytest.raises(ValueError, match="contains invalid keys"):
230+
threshold_metrics_label_units(metrics, thresholds)

src/spikeinterface/curation/threshold_metrics_curation.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ def threshold_metrics_label_units(
2626
thresholds : dict | str | Path
2727
A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
2828
Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
29-
should contain at least "min" and/or "max" keys to specify threshold ranges. Optionally, an "abs": True entry
30-
can be included to indicate that the metric should be treated as an absolute value when applying thresholds.
29+
should contain at least "greater" and/or "less" keys to specify threshold ranges. Thresholds are inclusive, i.e.
30+
"greater" is >= and "less" is <=. Optionally, an "abs": True entry can be included to indicate that the metric
31+
should be treated as an absolute value when applying thresholds.
3132
pass_label : str, default: "good"
3233
The label to assign to units that pass all thresholds.
3334
fail_label : str, default: "noise"
@@ -74,6 +75,14 @@ def threshold_metrics_label_units(
7475
f"Available metrics are: {metrics.columns.tolist()}"
7576
)
7677

78+
# Check that threshold dictionaries contain only valid keys
79+
valid_keys = {"greater", "less", "abs"}
80+
for metric_name, threshold in thresholds_dict.items():
81+
if not set(threshold).issubset(valid_keys):
82+
raise ValueError(
83+
f"Threshold for metric '{metric_name}' contains invalid keys {set(threshold) - valid_keys}."
84+
)
85+
7786
if operator not in ("and", "or"):
7887
raise ValueError("operator must be 'and' or 'or'")
7988

@@ -88,8 +97,8 @@ def threshold_metrics_label_units(
8897
any_threshold_applied = False
8998

9099
for metric_name, threshold in thresholds_dict.items():
91-
min_value = threshold.get("min", None)
92-
max_value = threshold.get("max", None)
100+
min_value = threshold.get("greater", None)
101+
max_value = threshold.get("less", None)
93102
abs_value = threshold.get("abs", False)
94103

95104
# If both disabled, ignore this metric

src/spikeinterface/widgets/bombcell_curation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class BombcellUpsetPlotWidget(BaseWidget):
3030
"non_soma", "non_soma_good", "non_soma_mua".
3131
thresholds : dict, optional
3232
Threshold dictionary with structure "noise", "mua", "non-somatic" as sections. Each section contains
33-
metric names keys with "min" and "max" thresholds.
33+
metric name keys with "greater" and "less" thresholds.
3434
If None, uses default thresholds.
3535
unit_labels_to_plot : list of str, optional
3636
List of unit labels to include in the plot. If None, defaults to all labels in thresholds.
@@ -197,10 +197,10 @@ def _build_failure_table(self, metrics, thresholds):
197197
values = np.abs(values)
198198

199199
failed = np.isnan(values)
200-
if not is_threshold_disabled(thresh.get("min", None)):
201-
failed |= values < thresh["min"]
202-
if not is_threshold_disabled(thresh.get("max", None)):
203-
failed |= values > thresh["max"]
200+
if not is_threshold_disabled(thresh.get("greater", None)):
201+
failed |= values < thresh["greater"]
202+
if not is_threshold_disabled(thresh.get("less", None)):
203+
failed |= values > thresh["less"]
204204
failure_data[metric_name] = failed
205205

206206
return pd.DataFrame(failure_data, index=metrics.index)

0 commit comments

Comments
 (0)