diff --git a/doc/how_to/auto_label_units.rst b/doc/how_to/auto_label_units.rst index a241613845..96f4cac8f8 100644 --- a/doc/how_to/auto_label_units.rst +++ b/doc/how_to/auto_label_units.rst @@ -79,18 +79,18 @@ curation: 1. Quality-metrics based curation -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +---------------------------------- A simple solution is to use a filter based on quality metrics. To do so, -we can use the ``spikeinterface.curation.qualitymetrics_label_units`` +we can use the ``spikeinterface.curation.threshold_metrics_label_units`` function and provide a set of thresholds. .. code:: ipython3 qm_thresholds = { - "snr": {"min": 5}, - "firing_rate": {"min": 0.1, "max": 200}, - "rp_contamination": {"max": 0.5} + "snr": {"greater": 5}, + "firing_rate": {"greater": 0.1, "less": 200}, + "rp_contamination": {"less": 0.5} } .. code:: ipython3 @@ -143,7 +143,7 @@ across all units: .. image:: auto_label_units_files/auto_label_units_14_0.png -1. Bombcell +2. Bombcell ----------- **Bombcell** ([Fabre]_) is another threshold-based method that also uses @@ -161,24 +161,24 @@ file. .. 
parsed-literal:: - {'mua': {'amplitude_cutoff': {'max': 0.2, 'min': None}, - 'amplitude_median': {'max': None, 'min': 40}, - 'drift_ptp': {'max': 100, 'min': None}, - 'num_spikes': {'max': None, 'min': 300}, - 'presence_ratio': {'max': None, 'min': 0.7}, - 'rp_contamination': {'max': 0.1, 'min': None}, - 'snr': {'max': None, 'min': 5}}, - 'noise': {'exp_decay': {'max': 0.1, 'min': 0.01}, - 'num_negative_peaks': {'max': 1, 'min': None}, - 'num_positive_peaks': {'max': 2, 'min': None}, - 'peak_after_to_trough_ratio': {'max': 0.8, 'min': None}, - 'peak_to_trough_duration': {'max': 0.00115, 'min': 0.0001}, - 'waveform_baseline_flatness': {'max': 0.5, 'min': None}}, - 'non-somatic': {'main_peak_to_trough_ratio': {'max': 0.8, 'min': None}, - 'peak_before_to_peak_after_ratio': {'max': 3, 'min': None}, - 'peak_before_to_trough_ratio': {'max': 3, 'min': None}, - 'peak_before_width': {'max': None, 'min': 0.00015}, - 'trough_width': {'max': None, 'min': 0.0002}}} + {'mua': {'amplitude_cutoff': {'greater': None, 'less': 0.2}, + 'amplitude_median': {'abs': True, 'greater': 30, 'less': None}, + 'drift_ptp': {'greater': None, 'less': 100}, + 'num_spikes': {'greater': 300, 'less': None}, + 'presence_ratio': {'greater': 0.7, 'less': None}, + 'rp_contamination': {'greater': None, 'less': 0.1}, + 'snr': {'greater': 5, 'less': None}}, + 'noise': {'exp_decay': {'greater': 0.01, 'less': 0.1}, + 'num_negative_peaks': {'greater': None, 'less': 1}, + 'num_positive_peaks': {'greater': None, 'less': 2}, + 'peak_after_to_trough_ratio': {'greater': None, 'less': 0.8}, + 'peak_to_trough_duration': {'greater': 0.0001, 'less': 0.00115}, + 'waveform_baseline_flatness': {'greater': None, 'less': 0.5}}, + 'non-somatic': {'main_peak_to_trough_ratio': {'greater': None, 'less': 0.8}, + 'peak_before_to_peak_after_ratio': {'greater': None, 'less': 3}, + 'peak_before_to_trough_ratio': {'greater': None, 'less': 3}, + 'peak_before_width': {'greater': 0.00015, 'less': None}, + 'trough_width': {'greater': 
0.0002, 'less': None}}} .. code:: ipython3 @@ -248,8 +248,8 @@ contamination (``rp_contamination``). .. image:: auto_label_units_files/auto_label_units_23_1.png -UnitRefine ----------- +3. UnitRefine +------------- **UnitRefine** ([Jain]_) also uses quality and template metrics, but in a different way. It uses pre-trained classifiers trained on @@ -305,12 +305,11 @@ sorting with different strategies. We recommend running **Bombcell** and **UnitRefine** as part of your pipeline. These methods will facilitate further curation and make downstream analysis cleaner. -To remove units from your ``SortingAnalyzer``, you can simply use the -``select_units`` function: - Remove units from ``SortingAnalyzer`` ------------------------------------- +To remove units from your ``SortingAnalyzer``, you can use the ``select_units`` function. + After auto-labeling, we can easily remove the “noise” units for downstream analysis: diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst index 5d2e8c968c..d9b877c9da 100644 --- a/doc/modules/curation.rst +++ b/doc/modules/curation.rst @@ -132,8 +132,8 @@ which applies a set of thresholds based on the available metrics (template/quali labels = threshold_metrics_label_units( sorting_analyzer=sorting_analyzer, thresholds={ - "snr": {"min": 5}, - "rp_contamination": {"max": 0.2}, + "snr": {"greater": 5}, + "rp_contamination": {"less": 0.2}, }, pass_label="good", fail_label="bad", diff --git a/examples/how_to/auto_label_units.py b/examples/how_to/auto_label_units.py index 4f5f14d569..a45a31f613 100644 --- a/examples/how_to/auto_label_units.py +++ b/examples/how_to/auto_label_units.py @@ -57,9 +57,9 @@ # %% qm_thresholds = { - "snr": {"min": 5}, - "firing_rate": {"min": 0.1, "max": 200}, - "rp_contamination": {"max": 0.5} + "snr": {"greater": 5}, + "firing_rate": {"greater": 0.1, "less": 200}, + "rp_contamination": {"less": 0.5} } # %% diff --git a/src/spikeinterface/curation/bombcell_curation.py 
b/src/spikeinterface/curation/bombcell_curation.py index ba5d6b51d0..888f5964ca 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -50,34 +50,34 @@ def bombcell_get_default_thresholds() -> dict: """ bombcell - Returns default thresholds for unit labeling. - Each metric has 'min' and 'max' values. Use None to disable a threshold (e.g. to ignore a metric completely - or to only have a min or a max threshold) + Each metric has 'greater' and 'less' values. Use None to disable a threshold (e.g. to ignore a metric completely + or to only have a greater or a less threshold) """ # bombcell return { "noise": { # failures -> NOISE - "num_positive_peaks": {"min": None, "max": 2}, - "num_negative_peaks": {"min": None, "max": 1}, - "peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds - "waveform_baseline_flatness": {"min": None, "max": 0.5}, - "peak_after_to_trough_ratio": {"min": None, "max": 0.8}, - "exp_decay": {"min": 0.01, "max": 0.1}, + "num_positive_peaks": {"greater": None, "less": 2}, + "num_negative_peaks": {"greater": None, "less": 1}, + "peak_to_trough_duration": {"greater": 0.0001, "less": 0.00115}, # seconds + "waveform_baseline_flatness": {"greater": None, "less": 0.5}, + "peak_after_to_trough_ratio": {"greater": None, "less": 0.8}, + "exp_decay": {"greater": 0.01, "less": 0.1}, }, "mua": { # failures -> MUA, only applied to units that passed noise thresholds - "amplitude_median": {"min": 30, "max": None, "abs": True}, # uV - "snr": {"min": 5, "max": None}, - "amplitude_cutoff": {"min": None, "max": 0.2}, - "num_spikes": {"min": 300, "max": None}, - "rp_contamination": {"min": None, "max": 0.1}, - "presence_ratio": {"min": 0.7, "max": None}, - "drift_ptp": {"min": None, "max": 100}, # um + "amplitude_median": {"greater": 30, "less": None, "abs": True}, # uV + "snr": {"greater": 5, "less": None}, + "amplitude_cutoff": {"greater": None, "less": 0.2}, + "num_spikes": {"greater": 
300, "less": None}, + "rp_contamination": {"greater": None, "less": 0.1}, + "presence_ratio": {"greater": 0.7, "less": None}, + "drift_ptp": {"greater": None, "less": 100}, # um }, "non-somatic": { - "peak_before_to_trough_ratio": {"min": None, "max": 3}, - "peak_before_width": {"min": 0.00015, "max": None}, # seconds - "trough_width": {"min": 0.0002, "max": None}, # seconds - "peak_before_to_peak_after_ratio": {"min": None, "max": 3}, - "main_peak_to_trough_ratio": {"min": None, "max": 0.8}, + "peak_before_to_trough_ratio": {"greater": None, "less": 3}, + "peak_before_width": {"greater": 0.00015, "less": None}, # seconds + "trough_width": {"greater": 0.0002, "less": None}, # seconds + "peak_before_to_peak_after_ratio": {"greater": None, "less": 3}, + "main_peak_to_trough_ratio": {"greater": None, "less": 0.8}, }, } @@ -123,7 +123,7 @@ def bombcell_label_units( If provided, metrics are extracted automatically using get_metrics_extension_data(). thresholds : dict | str | Path | None Threshold dict or JSON file, including three sections ("noise", "mua", "non-somatic") of - {"metric": {"min": val, "max": val}}. + {"metric": {"greater": val, "less": val}}. If None, default Bombcell thresholds are used. label_non_somatic : bool, default: True If True, detect non-somatic (dendritic, axonal) units. 
@@ -336,8 +336,8 @@ def save_bombcell_results( continue value = metrics.loc[unit_id, metric_name] thresh = flat_thresholds[metric_name] - thresh_min = thresh.get("min", None) - thresh_max = thresh.get("max", None) + thresh_min = thresh.get("greater", None) + thresh_max = thresh.get("less", None) # Determine pass/fail passed = True diff --git a/src/spikeinterface/curation/tests/test_bombcell_curation.py b/src/spikeinterface/curation/tests/test_bombcell_curation.py index 23f566ee9e..768d28c143 100644 --- a/src/spikeinterface/curation/tests/test_bombcell_curation.py +++ b/src/spikeinterface/curation/tests/test_bombcell_curation.py @@ -55,8 +55,8 @@ def test_bombcell_label_units_with_threshold_file(sorting_analyzer_with_metrics, # Define custom thresholds custom_thresholds = { - "snr": {"min": 5, "max": 100}, - "isi_violations": {"min": None, "max": 0.2}, + "snr": {"greater": 5, "less": 100}, + "isi_violations": {"greater": None, "less": 0.2}, } # Save thresholds to a temporary JSON file diff --git a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py index 82e0400b29..658285d670 100644 --- a/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py +++ b/src/spikeinterface/curation/tests/test_threshold_metrics_curation.py @@ -17,8 +17,8 @@ def test_threshold_metrics_label_units_with_dataframe(): index=[0, 1, 2], ) thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1, "max": 20.0}, + "snr": {"greater": 5.0}, + "firing_rate": {"greater": 0.1, "less": 20.0}, } labels = threshold_metrics_label_units(metrics, thresholds) @@ -39,8 +39,8 @@ def test_threshold_metrics_label_units_with_file(tmp_path): index=[0, 1], ) thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1}, + "snr": {"greater": 5.0}, + "firing_rate": {"greater": 0.1}, } thresholds_file = tmp_path / "thresholds.json" @@ -63,8 +63,8 @@ def test_threshold_metrics_label_external_labels(): 
index=[0, 1], ) thresholds = { - "snr": {"min": 5.0}, - "firing_rate": {"min": 0.1}, + "snr": {"greater": 5.0}, + "firing_rate": {"greater": 0.1}, } labels = threshold_metrics_label_units( @@ -86,7 +86,7 @@ def test_threshold_metrics_label_units_operator_or_with_dataframe(): }, index=[0, 1, 2, 3], ) - thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}} labels_and = threshold_metrics_label_units( metrics, @@ -115,7 +115,7 @@ def test_threshold_metrics_label_units_nan_policy_fail_vs_ignore_and(): }, index=[10, 11, 12], ) - thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}} labels_fail = threshold_metrics_label_units( metrics, @@ -147,7 +147,7 @@ def test_threshold_metrics_label_units_nan_policy_ignore_with_or(): }, index=[20, 21], ) - thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}} labels_ignore_or = threshold_metrics_label_units( metrics, @@ -170,7 +170,7 @@ def test_threshold_metrics_label_units_nan_policy_pass_and_or(): }, index=[30, 31, 32, 33], ) - thresholds = {"m1": {"min": 0.0}, "m2": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}, "m2": {"greater": 0.0}} labels_and = threshold_metrics_label_units( metrics, @@ -198,7 +198,7 @@ def test_threshold_metrics_label_units_invalid_operator_raises(): import pandas as pd metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) - thresholds = {"m1": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}} with pytest.raises(ValueError, match="operator must be 'and' or 'or'"): threshold_metrics_label_units(metrics, thresholds, operator="xor") @@ -207,7 +207,7 @@ def test_threshold_metrics_label_units_invalid_nan_policy_raises(): import pandas as pd metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) - thresholds = {"m1": {"min": 0.0}} + thresholds = {"m1": {"greater": 0.0}} with pytest.raises(ValueError, match="nan_policy must 
be"): threshold_metrics_label_units(metrics, thresholds, nan_policy="omit") @@ -216,6 +216,15 @@ def test_threshold_metrics_label_units_missing_metric_raises(): import pandas as pd metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) - thresholds = {"does_not_exist": {"min": 0.0}} + thresholds = {"does_not_exist": {"greater": 0.0}} with pytest.raises(ValueError, match="specified in thresholds are not present"): threshold_metrics_label_units(metrics, thresholds) + + +def test_threshold_metrics_label_units_invalid_threshold_keys_raises(): + import pandas as pd + + metrics = pd.DataFrame({"m1": [1.0]}, index=[0]) + thresholds = {"m1": {"greater": 0.0, "invalid_key": 1.0}} + with pytest.raises(ValueError, match="contains invalid keys"): + threshold_metrics_label_units(metrics, thresholds) diff --git a/src/spikeinterface/curation/threshold_metrics_curation.py b/src/spikeinterface/curation/threshold_metrics_curation.py index 8aff6cdfbe..b9b5ca1a86 100644 --- a/src/spikeinterface/curation/threshold_metrics_curation.py +++ b/src/spikeinterface/curation/threshold_metrics_curation.py @@ -26,8 +26,9 @@ def threshold_metrics_label_units( thresholds : dict | str | Path A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units. Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values - should contain at least "min" and/or "max" keys to specify threshold ranges. Optionally, an "abs": True entry - can be included to indicate that the metric should be treated as an absolute value when applying thresholds. + should contain at least "greater" and/or "less" keys to specify threshold ranges. Thresholds are inclusive, i.e. + "greater" is >= and "less" is <=. Optionally, an "abs": True entry can be included to indicate that the metric + should be treated as an absolute value when applying thresholds. pass_label : str, default: "good" The label to assign to units that pass all thresholds. 
fail_label : str, default: "noise" @@ -74,6 +75,14 @@ def threshold_metrics_label_units( f"Available metrics are: {metrics.columns.tolist()}" ) + # Check that threshold dictionaries contain only valid keys + valid_keys = {"greater", "less", "abs"} + for metric_name, threshold in thresholds_dict.items(): + if not set(threshold).issubset(valid_keys): + raise ValueError( + f"Threshold for metric '{metric_name}' contains invalid keys {set(threshold) - valid_keys}." + ) + if operator not in ("and", "or"): raise ValueError("operator must be 'and' or 'or'") @@ -88,8 +97,8 @@ def threshold_metrics_label_units( any_threshold_applied = False for metric_name, threshold in thresholds_dict.items(): - min_value = threshold.get("min", None) - max_value = threshold.get("max", None) + min_value = threshold.get("greater", None) + max_value = threshold.get("less", None) abs_value = threshold.get("abs", False) # If both disabled, ignore this metric diff --git a/src/spikeinterface/widgets/bombcell_curation.py b/src/spikeinterface/widgets/bombcell_curation.py index 4f120f4451..1a1212ba5c 100644 --- a/src/spikeinterface/widgets/bombcell_curation.py +++ b/src/spikeinterface/widgets/bombcell_curation.py @@ -30,7 +30,7 @@ class BombcellUpsetPlotWidget(BaseWidget): "non_soma", "non_soma_good", "non_soma_mua". thresholds : dict, optional Threshold dictionary with structure "noise", "mua", "non-somatic" as sections. Each section contains - metric names keys with "min" and "max" thresholds. + metric names keys with "greater" and "less" thresholds. If None, uses default thresholds. unit_labels_to_plot : list of str, optional List of unit labels to include in the plot. If None, defaults to all labels in thresholds. 
@@ -197,10 +197,10 @@ def _build_failure_table(self, metrics, thresholds): values = np.abs(values) failed = np.isnan(values) - if not is_threshold_disabled(thresh.get("min", None)): - failed |= values < thresh["min"] - if not is_threshold_disabled(thresh.get("max", None)): - failed |= values > thresh["max"] + if not is_threshold_disabled(thresh.get("greater", None)): + failed |= values < thresh["greater"] + if not is_threshold_disabled(thresh.get("less", None)): + failed |= values > thresh["less"] failure_data[metric_name] = failed return pd.DataFrame(failure_data, index=metrics.index) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 949887b0a0..38484b3533 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -334,7 +334,7 @@ class MetricsHistogramsWidget(BaseWidget): sorting_analyzer : SortingAnalyzer A SortingAnalyzer object with quality_metrics and/or template_metrics extensions computed. thresholds : dict, optional - Dictionary of metric thresholds. Can be a flat dict with metric names as keys and dicts with 'min' and/or 'max' + Dictionary of metric thresholds. Can be a flat dict with metric names as keys and dicts with 'greater' and/or 'less' as values, or a nested dict where top-level keys are different categories. Optionally, an "abs": True entry can be included in each metric's dict to indicate that the metric should be treated as an absolute value when applying thresholds. If None, default thresholds from `bombcell_get_default_thresholds` will be used. @@ -365,11 +365,11 @@ def __init__( assert isinstance(thresholds, dict), ( "Thresholds should be provided as a dictionary (optionally nested) with metric names as keys and dicts " - "with 'min' and/or 'max' as values." + "with 'greater' and/or 'less' as values." ) # Flatten thresholds for easier access (if subdicts are present). 
- # We check if all entries have a "min" or "max" key to determine if it's a nested dict of metrics or a flat dict. - if all(isinstance(value, dict) and ("min" in value or "max" in value) for value in thresholds.values()): + # We check if all entries have a "greater" or "less" key to determine if it's a nested dict of metrics or a flat dict. + if all(isinstance(value, dict) and ("greater" in value or "less" in value) for value in thresholds.values()): flat_thresholds = thresholds else: flat_thresholds = {} @@ -377,8 +377,8 @@ def __init__( assert isinstance(subdict, dict), "Each category in thresholds should be a dict of metric thresholds." for metric_name, thresh in subdict.items(): assert isinstance(thresh, dict) and ( - "min" in thresh or "max" in thresh - ), "Each threshold entry should be a dict with 'min' and/or 'max' keys." + "greater" in thresh or "less" in thresh + ), "Each threshold entry should be a dict with 'greater' and/or 'less' keys." flat_thresholds[metric_name] = thresh if metrics_to_plot is None: @@ -434,21 +434,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): thresh = thresholds.get(metric_name, {}) has_thresh = False - if not is_threshold_disabled(thresh.get("min", None)): - label = ( - f"min={int(thresh['min'])}" - if float(thresh["min"]).is_integer() - else f"min={float(thresh['min']):.2f}" - ) - ax.axvline(thresh["min"], color="red", ls="--", lw=2, label=label) + if not is_threshold_disabled(thresh.get("greater", None)): + value = float(thresh["greater"]) + label = f">={int(value)}" if value.is_integer() else f">={value:.2f}" + ax.axvline(value, color="red", ls="--", lw=2, label=label) has_thresh = True - if not is_threshold_disabled(thresh.get("max", None)): - label = ( - f"max={int(thresh['max'])}" - if float(thresh["max"]).is_integer() - else f"max={float(thresh['max']):.2f}" - ) - ax.axvline(thresh["max"], color="blue", ls="--", lw=2, label=label) + if not is_threshold_disabled(thresh.get("less", None)): + value = 
float(thresh["less"]) + label = f"<={int(value)}" if value.is_integer() else f"<={value:.2f}" + ax.axvline(value, color="blue", ls="--", lw=2, label=label) has_thresh = True ax.set_xlabel(metric_name) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7f4ee8860a..c811c386e2 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -729,7 +729,7 @@ def test_plot_metric_histograms(self): possible_backends = list(sw.MetricsHistogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: - thresholds = {"snr": {"min": 5}, "isi_violation": {"max": 0.5}} + thresholds = {"snr": {"greater": 5}, "isi_violation": {"less": 0.5}} sw.plot_metric_histograms( self.sorting_analyzer_dense, thresholds=thresholds, backend=backend, **self.backend_kwargs[backend] )