From e07b4f5f01b5d958241548e1a91a2557c97364a3 Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Mon, 28 Apr 2025 16:11:42 +0100 Subject: [PATCH] add numba check to sd ratio --- src/spikeinterface/curation/curation_tools.py | 4 ++-- .../qualitymetrics/misc_metrics.py | 20 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 3402638a16..a456564f29 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -127,10 +127,10 @@ def find_duplicated_spikes( assert seed is not None, "The 'seed' has to be provided if method=='random'" return _find_duplicated_spikes_random(spike_train, censored_period, seed) elif method == "keep_first_iterative": - assert HAVE_NUMBA, "'keep_first' method requires numba. Install it with >>> pip install numba" + assert HAVE_NUMBA, "'keep_first_iterative' method requires numba. Install it with >>> pip install numba" return _find_duplicated_spikes_keep_first_iterative(spike_train.astype(np.int64), censored_period) elif method == "keep_last_iterative": - assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba" + assert HAVE_NUMBA, "'keep_last_iterative' method requires numba. Install it with >>> pip install numba" return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period) else: raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}") diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2c04350c9b..3a3aab0bf4 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -1469,7 +1469,8 @@ def compute_sd_ratio( num_spikes : dict The number of spikes, across all segments, for each unit ID. """ - from spikeinterface.curation.curation_tools import _find_duplicated_spikes_keep_first_iterative + + from spikeinterface.curation.curation_tools import find_duplicated_spikes kwargs, job_kwargs = split_job_kwargs(kwargs) job_kwargs = fix_job_kwargs(job_kwargs) @@ -1487,9 +1488,15 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} + if not HAVE_NUMBA: + warnings.warn( + "'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. " + "SD ratio metric will be set to NaN" + ) + return {unit_id: np.nan for unit_id in unit_ids} + if sorting_analyzer.has_extension("spike_amplitudes"): amplitudes_ext = sorting_analyzer.get_extension("spike_amplitudes") - # spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit") spike_amplitudes = amplitudes_ext.get_data() else: warnings.warn( @@ -1516,18 +1523,17 @@ def compute_sd_ratio( spk_amp = [] for segment_index in range(sorting_analyzer.get_num_segments()): - # spike_train = sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype( - # np.int64, copy=False - # ) + spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index) spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False) amplitudes = spike_amplitudes[spike_mask] - censored_indices = _find_duplicated_spikes_keep_first_iterative( + censored_indices = find_duplicated_spikes( spike_train, censored_period, + method="keep_first_iterative", ) - # spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices)) + spk_amp.append(np.delete(amplitudes, censored_indices)) spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))])