1515import numpy as np
1616
1717from spikeinterface .core .analyzer_extension_core import BaseMetric
18- from spikeinterface .core .job_tools import fix_job_kwargs , split_job_kwargs
19- from spikeinterface .core import SortingAnalyzer , get_noise_levels , NumpySorting
18+ from spikeinterface .core import SortingAnalyzer , NumpySorting
2019from spikeinterface .core .template_tools import (
2120 get_template_extremum_channel ,
2221 get_template_extremum_amplitude ,
@@ -1359,7 +1358,8 @@ def compute_sd_ratio(
13591358 censored_period_ms : float = 4.0 ,
13601359 correct_for_drift : bool = True ,
13611360 correct_for_template_itself : bool = True ,
1362- ** kwargs ,
1361+ peak_sign : str = "neg" ,
1362+ ** job_kwargs ,
13631363):
13641364 """
13651365 Computes the SD (Standard Deviation) of each unit's spike amplitudes, and compare it to the SD of noise.
@@ -1384,20 +1384,21 @@ def compute_sd_ratio(
13841384 correct_for_template_itself : bool, default: True
13851385 If true, will take into account that the template itself impacts the standard deviation of the noise,
13861386 and will make a rough estimation of what that impact is (and remove it).
1387- **kwargs : dict, default: {}
1388- Keyword arguments for computing spike amplitudes and extremum channel.
1387+ peak_sign : "neg" | "pos" | "both", default: "neg"
1388+ The peak sign used to select the template extremum channel.
1389+ **job_kwargs : dict, default: {}
1390+ Keyword arguments sent to get_noise_levels.
13891391
13901392 Returns
13911393 -------
1392- num_spikes : dict
1393- The number of spikes, across all segments , for each unit ID.
1394+ sd_ratio : dict
1395+ The ratio of the standard deviation of spike amplitudes to the standard deviation of noise , for each unit ID.
13941396 """
13951397
13961398 from spikeinterface .curation .curation_tools import find_duplicated_spikes
1399+ from spikeinterface .core import get_noise_levels
13971400
13981401 check_has_required_extensions ("sd_ratio" , sorting_analyzer )
1399- kwargs , job_kwargs = split_job_kwargs (kwargs )
1400- job_kwargs = fix_job_kwargs (job_kwargs )
14011402
14021403 sorting = sorting_analyzer .sorting
14031404 sorting = sorting .select_periods (periods = periods )
@@ -1429,11 +1430,11 @@ def compute_sd_ratio(
14291430 noise_levels = get_noise_levels (
14301431 sorting_analyzer .recording , return_in_uV = sorting_analyzer .return_in_uV , method = "std" , ** job_kwargs
14311432 )
1432- best_channels = get_template_extremum_channel (sorting_analyzer , outputs = "index" , ** kwargs )
1433- n_spikes = sorting_analyzer .sorting .count_num_spikes_per_unit (unit_ids = unit_ids )
1433+ best_channels = get_template_extremum_channel (sorting_analyzer , outputs = "index" , peak_sign = peak_sign )
14341434
14351435 if correct_for_template_itself :
1436- tamplates_array = get_dense_templates_array (sorting_analyzer , return_in_uV = sorting_analyzer .return_in_uV )
1436+ n_spikes = sorting_analyzer .sorting .count_num_spikes_per_unit (unit_ids = unit_ids )
1437+ templates_array = get_dense_templates_array (sorting_analyzer , return_in_uV = sorting_analyzer .return_in_uV )
14371438
14381439 sd_ratio = {}
14391440
@@ -1468,21 +1469,17 @@ def compute_sd_ratio(
14681469 best_channel = best_channels [unit_id ]
14691470 std_noise = noise_levels [best_channel ]
14701471
1471- n_samples = sorting_analyzer .get_total_samples ()
1472-
14731472 if correct_for_template_itself :
14741473 # template = sorting_analyzer.get_template(unit_id, force_dense=True)[:, best_channel]
14751474 unit_index = sorting .id_to_index (unit_id )
1476-
1477- template = tamplates_array [unit_index , :, :][:, best_channel ]
1478- nsamples = template .shape [0 ]
1475+ template = templates_array [unit_index , :, best_channel ]
14791476
14801477 # Computing the variance of a trace that is all 0 and n_spikes non-overlapping template.
14811478 # TODO: Take into account that templates for different segments might differ.
1482- p = nsamples * n_spikes [unit_id ] / n_samples
1483- total_variance = p * np .mean (template ** 2 ) - p ** 2 * np .mean (template ) ** 2
1479+ p = len ( template ) * n_spikes [unit_id ] / sorting_analyzer . get_total_samples ()
1480+ template_variance = p * np .mean (template ** 2 ) - p ** 2 * np .mean (template ) ** 2
14841481
1485- std_noise = np .sqrt (std_noise ** 2 - total_variance )
1482+ std_noise = np .sqrt (std_noise ** 2 - template_variance )
14861483
14871484 sd_ratio [unit_id ] = unit_std / std_noise
14881485
@@ -1496,6 +1493,7 @@ class SDRatio(BaseMetric):
14961493 "censored_period_ms" : 4.0 ,
14971494 "correct_for_drift" : True ,
14981495 "correct_for_template_itself" : True ,
1496+ "peak_sign" : "neg" ,
14991497 }
15001498 metric_columns = {"sd_ratio" : float }
15011499 metric_descriptions = {
0 commit comments