@@ -1469,7 +1469,8 @@ def compute_sd_ratio(
14691469 num_spikes : dict
14701470 The number of spikes, across all segments, for each unit ID.
14711471 """
1472- from spikeinterface .curation .curation_tools import _find_duplicated_spikes_keep_first_iterative
1472+
1473+ from spikeinterface .curation .curation_tools import find_duplicated_spikes
14731474
14741475 kwargs , job_kwargs = split_job_kwargs (kwargs )
14751476 job_kwargs = fix_job_kwargs (job_kwargs )
@@ -1487,9 +1488,15 @@ def compute_sd_ratio(
14871488 )
14881489 return {unit_id : np .nan for unit_id in unit_ids }
14891490
1491+ if not HAVE_NUMBA :
1492+ warnings .warn (
1493+ "'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. "
1494+ "SD ratio metric will be set to NaN"
1495+ )
1496+ return {unit_id : np .nan for unit_id in unit_ids }
1497+
14901498 if sorting_analyzer .has_extension ("spike_amplitudes" ):
14911499 amplitudes_ext = sorting_analyzer .get_extension ("spike_amplitudes" )
1492- # spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit")
14931500 spike_amplitudes = amplitudes_ext .get_data ()
14941501 else :
14951502 warnings .warn (
@@ -1516,18 +1523,17 @@ def compute_sd_ratio(
15161523 spk_amp = []
15171524
15181525 for segment_index in range (sorting_analyzer .get_num_segments ()):
1519- # spike_train = sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
1520- # np.int64, copy=False
1521- # )
1526+
15221527 spike_mask = (spikes ["unit_index" ] == unit_index ) & (spikes ["segment_index" ] == segment_index )
15231528 spike_train = spikes [spike_mask ]["sample_index" ].astype (np .int64 , copy = False )
15241529 amplitudes = spike_amplitudes [spike_mask ]
15251530
1526- censored_indices = _find_duplicated_spikes_keep_first_iterative (
1531+ censored_indices = find_duplicated_spikes (
15271532 spike_train ,
15281533 censored_period ,
1534+ method = "keep_first_iterative" ,
15291535 )
1530- # spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices))
1536+
15311537 spk_amp .append (np .delete (amplitudes , censored_indices ))
15321538
15331539 spk_amp = np .concatenate ([spk_amp [i ] for i in range (len (spk_amp ))])
0 commit comments