Skip to content

Commit e07b4f5

Browse files
committed
add numba check to sd ratio
1 parent cb83327 commit e07b4f5

2 files changed

Lines changed: 15 additions & 9 deletions

File tree

src/spikeinterface/curation/curation_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ def find_duplicated_spikes(
127127
assert seed is not None, "The 'seed' has to be provided if method=='random'"
128128
return _find_duplicated_spikes_random(spike_train, censored_period, seed)
129129
elif method == "keep_first_iterative":
130-
assert HAVE_NUMBA, "'keep_first' method requires numba. Install it with >>> pip install numba"
130+
assert HAVE_NUMBA, "'keep_first_iterative' method requires numba. Install it with >>> pip install numba"
131131
return _find_duplicated_spikes_keep_first_iterative(spike_train.astype(np.int64), censored_period)
132132
elif method == "keep_last_iterative":
133-
assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba"
133+
assert HAVE_NUMBA, "'keep_last_iterative' method requires numba. Install it with >>> pip install numba"
134134
return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period)
135135
else:
136136
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}")

src/spikeinterface/qualitymetrics/misc_metrics.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,8 @@ def compute_sd_ratio(
14691469
num_spikes : dict
14701470
The number of spikes, across all segments, for each unit ID.
14711471
"""
1472-
from spikeinterface.curation.curation_tools import _find_duplicated_spikes_keep_first_iterative
1472+
1473+
from spikeinterface.curation.curation_tools import find_duplicated_spikes
14731474

14741475
kwargs, job_kwargs = split_job_kwargs(kwargs)
14751476
job_kwargs = fix_job_kwargs(job_kwargs)
@@ -1487,9 +1488,15 @@ def compute_sd_ratio(
14871488
)
14881489
return {unit_id: np.nan for unit_id in unit_ids}
14891490

1491+
if not HAVE_NUMBA:
1492+
warnings.warn(
1493+
"'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. "
1494+
"SD ratio metric will be set to NaN"
1495+
)
1496+
return {unit_id: np.nan for unit_id in unit_ids}
1497+
14901498
if sorting_analyzer.has_extension("spike_amplitudes"):
14911499
amplitudes_ext = sorting_analyzer.get_extension("spike_amplitudes")
1492-
# spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit")
14931500
spike_amplitudes = amplitudes_ext.get_data()
14941501
else:
14951502
warnings.warn(
@@ -1516,18 +1523,17 @@ def compute_sd_ratio(
15161523
spk_amp = []
15171524

15181525
for segment_index in range(sorting_analyzer.get_num_segments()):
1519-
# spike_train = sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
1520-
# np.int64, copy=False
1521-
# )
1526+
15221527
spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index)
15231528
spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False)
15241529
amplitudes = spike_amplitudes[spike_mask]
15251530

1526-
censored_indices = _find_duplicated_spikes_keep_first_iterative(
1531+
censored_indices = find_duplicated_spikes(
15271532
spike_train,
15281533
censored_period,
1534+
method="keep_first_iterative",
15291535
)
1530-
# spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices))
1536+
15311537
spk_amp.append(np.delete(amplitudes, censored_indices))
15321538

15331539
spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))])

0 commit comments

Comments
 (0)