Skip to content

Commit 1228d1d

Browse files
committed
Add optional 'abs' entry to threshold dict
1 parent 2ae7120 commit 1228d1d

3 files changed

Lines changed: 11 additions & 9 deletions

File tree

src/spikeinterface/curation/bombcell_curation.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def bombcell_get_default_thresholds() -> dict:
6666
"mua": { # failures -> MUA, only applied to units that passed noise thresholds
6767
"amplitude_median": {"min": 40, "max": None}, # uV
6868
"snr": {"min": 5, "max": None},
69-
"amplitude_cutoff": {"min": None, "max": 0.2},
69+
"amplitude_cutoff": {"min": None, "max": 0.2, "abs": True},
7070
"num_spikes": {"min": 300, "max": None},
7171
"rp_contamination": {"min": None, "max": 0.1},
7272
"presence_ratio": {"min": 0.7, "max": None},
@@ -174,10 +174,6 @@ def bombcell_label_units(
174174
raise ValueError("thresholds must be a dict, a JSON file path, or None")
175175

176176
n_units = len(combined_metrics)
177-
absolute_value_metrics = ["amplitude_median"]
178-
for metric in absolute_value_metrics:
179-
if metric in combined_metrics.columns:
180-
combined_metrics[metric] = np.abs(combined_metrics[metric])
181177

182178
noise_thresholds = thresholds_dict.get("noise", {})
183179
if len(noise_thresholds) > 0:

src/spikeinterface/curation/threshold_metrics_curation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def threshold_metrics_label_units(
2626
thresholds : dict | str | Path
2727
A dictionary or JSON file path where keys are metric names and values are threshold values for labeling units.
2828
Each key should correspond to a quality metric present in the analyzer's quality metrics DataFrame. Values
29-
should contain at least "min" and/or "max" keys to specify threshold ranges.
29+
should contain at least "min" and/or "max" keys to specify threshold ranges. Optionally, an "abs": True entry
30+
can be included to indicate that the metric should be treated as an absolute value when applying thresholds.
3031
pass_label : str, default: "good"
3132
The label to assign to units that pass all thresholds.
3233
fail_label : str, default: "noise"
@@ -89,12 +90,15 @@ def threshold_metrics_label_units(
8990
for metric_name, threshold in thresholds_dict.items():
9091
min_value = threshold.get("min", None)
9192
max_value = threshold.get("max", None)
93+
abs_value = threshold.get("abs", False)
9294

9395
# If both disabled, ignore this metric
9496
if is_threshold_disabled(min_value) and is_threshold_disabled(max_value):
9597
continue
9698

9799
values = metrics[metric_name].to_numpy()
100+
if abs_value:
101+
values = np.abs(values)
98102
is_nan = np.isnan(values)
99103

100104
metric_ok = np.ones(len(values), dtype=bool)

src/spikeinterface/widgets/metrics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,9 @@ class MetricsHistogramsWidget(BaseWidget):
335335
A SortingAnalyzer object with quality_metrics and/or template_metrics extensions computed.
336336
thresholds : dict, optional
337337
Dictionary of metric thresholds. Can be a flat dict with metric names as keys and dicts with 'min' and/or 'max'
338-
as values, or a nested dict where top-level keys are different categories.
338+
as values, or a nested dict where top-level keys are different categories. Optionally, an "abs": True entry
339+
can be included in each metric's dict to indicate that the metric should be treated as an absolute value when
340+
applying thresholds. If None, default thresholds from `bombcell_get_default_thresholds` will be used.
339341
metrics_to_plot : list, default: None
340342
List of metric names to plot. If None, all metrics with thresholds will be plotted.
341343
"""
@@ -412,15 +414,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
412414
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
413415

414416
colors = plt.cm.tab10(np.linspace(0, 1, 10))
415-
absolute_value_metrics = ["amplitude_median"]
416417

417418
axes = self.axes
418419
for idx, metric_name in enumerate(metrics_to_plot):
419420
row, col = idx // n_cols, idx % n_cols
420421
ax = axes[row, col]
421422

422423
values = metrics[metric_name].values
423-
if metric_name in absolute_value_metrics:
424+
abs_value = thresholds.get(metric_name, {}).get("abs", False)
425+
if abs_value:
424426
values = np.abs(values)
425427
values = values[~np.isnan(values) & ~np.isinf(values)]
426428

0 commit comments

Comments
 (0)