Skip to content

Commit 747ee14

Browse files
ecobostpre-commit-ci[bot]chrishalcrow
authored
Add peak_sign to sd_ratio (#4362)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent 23c96f9 commit 747ee14

1 file changed

Lines changed: 18 additions & 20 deletions

File tree

src/spikeinterface/metrics/quality/misc_metrics.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
import numpy as np
1616

1717
from spikeinterface.core.analyzer_extension_core import BaseMetric
18-
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
19-
from spikeinterface.core import SortingAnalyzer, get_noise_levels, NumpySorting
18+
from spikeinterface.core import SortingAnalyzer, NumpySorting
2019
from spikeinterface.core.template_tools import (
2120
get_template_extremum_channel,
2221
get_template_extremum_amplitude,
@@ -1359,7 +1358,8 @@ def compute_sd_ratio(
13591358
censored_period_ms: float = 4.0,
13601359
correct_for_drift: bool = True,
13611360
correct_for_template_itself: bool = True,
1362-
**kwargs,
1361+
peak_sign: str = "neg",
1362+
**job_kwargs,
13631363
):
13641364
"""
13651365
Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise.
@@ -1384,20 +1384,21 @@ def compute_sd_ratio(
13841384
correct_for_template_itself : bool, default: True
13851385
If true, will take into account that the template itself impacts the standard deviation of the noise,
13861386
and will make a rough estimation of what that impact is (and remove it).
1387-
**kwargs : dict, default: {}
1388-
Keyword arguments for computing spike amplitudes and extremum channel.
1387+
peak_sign : "neg" | "pos" | "both", default: "neg"
1388+
The peak sign used to select the template extremum channel.
1389+
**job_kwargs : dict, default: {}
1390+
Keyword arguments sent to get_noise_levels.
13891391
13901392
Returns
13911393
-------
1392-
num_spikes : dict
1393-
The number of spikes, across all segments, for each unit ID.
1394+
sd_ratio : dict
1395+
The ratio of the standard deviation of spike amplitudes to the standard deviation of noise, for each unit ID.
13941396
"""
13951397

13961398
from spikeinterface.curation.curation_tools import find_duplicated_spikes
1399+
from spikeinterface.core import get_noise_levels
13971400

13981401
check_has_required_extensions("sd_ratio", sorting_analyzer)
1399-
kwargs, job_kwargs = split_job_kwargs(kwargs)
1400-
job_kwargs = fix_job_kwargs(job_kwargs)
14011402

14021403
sorting = sorting_analyzer.sorting
14031404
sorting = sorting.select_periods(periods=periods)
@@ -1429,11 +1430,11 @@ def compute_sd_ratio(
14291430
noise_levels = get_noise_levels(
14301431
sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs
14311432
)
1432-
best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", **kwargs)
1433-
n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
1433+
best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", peak_sign=peak_sign)
14341434

14351435
if correct_for_template_itself:
1436-
tamplates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV)
1436+
n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids)
1437+
templates_array = get_dense_templates_array(sorting_analyzer, return_in_uV=sorting_analyzer.return_in_uV)
14371438

14381439
sd_ratio = {}
14391440

@@ -1468,21 +1469,17 @@ def compute_sd_ratio(
14681469
best_channel = best_channels[unit_id]
14691470
std_noise = noise_levels[best_channel]
14701471

1471-
n_samples = sorting_analyzer.get_total_samples()
1472-
14731472
if correct_for_template_itself:
14741473
# template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel]
14751474
unit_index = sorting.id_to_index(unit_id)
1476-
1477-
template = tamplates_array[unit_index, :, :][:, best_channel]
1478-
nsamples = template.shape[0]
1475+
template = templates_array[unit_index, :, best_channel]
14791476

14801477
# Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
14811478
# TODO: Take into account that templates for different segments might differ.
1482-
p = nsamples * n_spikes[unit_id] / n_samples
1483-
total_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
1479+
p = len(template) * n_spikes[unit_id] / sorting_analyzer.get_total_samples()
1480+
template_variance = p * np.mean(template**2) - p**2 * np.mean(template) ** 2
14841481

1485-
std_noise = np.sqrt(std_noise**2 - total_variance)
1482+
std_noise = np.sqrt(std_noise**2 - template_variance)
14861483

14871484
sd_ratio[unit_id] = unit_std / std_noise
14881485

@@ -1496,6 +1493,7 @@ class SDRatio(BaseMetric):
14961493
"censored_period_ms": 4.0,
14971494
"correct_for_drift": True,
14981495
"correct_for_template_itself": True,
1496+
"peak_sign": "neg",
14991497
}
15001498
metric_columns = {"sd_ratio": float}
15011499
metric_descriptions = {

0 commit comments

Comments
 (0)