Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def find_duplicated_spikes(
assert seed is not None, "The 'seed' has to be provided if method=='random'"
return _find_duplicated_spikes_random(spike_train, censored_period, seed)
elif method == "keep_first_iterative":
assert HAVE_NUMBA, "'keep_first' method requires numba. Install it with >>> pip install numba"
assert HAVE_NUMBA, "'keep_first_iterative' method requires numba. Install it with >>> pip install numba"
return _find_duplicated_spikes_keep_first_iterative(spike_train.astype(np.int64), censored_period)
elif method == "keep_last_iterative":
assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba"
assert HAVE_NUMBA, "'keep_last_iterative' method requires numba. Install it with >>> pip install numba"
return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period)
else:
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}")
Expand Down
20 changes: 13 additions & 7 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,8 @@ def compute_sd_ratio(
num_spikes : dict
The number of spikes, across all segments, for each unit ID.
"""
from spikeinterface.curation.curation_tools import _find_duplicated_spikes_keep_first_iterative

from spikeinterface.curation.curation_tools import find_duplicated_spikes

kwargs, job_kwargs = split_job_kwargs(kwargs)
job_kwargs = fix_job_kwargs(job_kwargs)
Expand All @@ -1487,9 +1488,15 @@ def compute_sd_ratio(
)
return {unit_id: np.nan for unit_id in unit_ids}

if not HAVE_NUMBA:
warnings.warn(
"'sd_ratio' metric computation requires numba. Install it with >>> pip install numba. "
"SD ratio metric will be set to NaN"
)
return {unit_id: np.nan for unit_id in unit_ids}
Comment thread
alejoe91 marked this conversation as resolved.

if sorting_analyzer.has_extension("spike_amplitudes"):
amplitudes_ext = sorting_analyzer.get_extension("spike_amplitudes")
# spike_amplitudes = amplitudes_ext.get_data(outputs="by_unit")
spike_amplitudes = amplitudes_ext.get_data()
else:
warnings.warn(
Expand All @@ -1516,18 +1523,17 @@ def compute_sd_ratio(
spk_amp = []

for segment_index in range(sorting_analyzer.get_num_segments()):
# spike_train = sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index).astype(
# np.int64, copy=False
# )

spike_mask = (spikes["unit_index"] == unit_index) & (spikes["segment_index"] == segment_index)
spike_train = spikes[spike_mask]["sample_index"].astype(np.int64, copy=False)
amplitudes = spike_amplitudes[spike_mask]

censored_indices = _find_duplicated_spikes_keep_first_iterative(
censored_indices = find_duplicated_spikes(
spike_train,
censored_period,
method="keep_first_iterative",
)
# spk_amp.append(np.delete(spike_amplitudes[segment_index][unit_id], censored_indices))

spk_amp.append(np.delete(amplitudes, censored_indices))

spk_amp = np.concatenate([spk_amp[i] for i in range(len(spk_amp))])
Expand Down