From 2ed9c37e3e94a957a4ae6f1c607dfc85f4b68272 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Thu, 10 Apr 2025 09:53:36 -0400 Subject: [PATCH 01/13] take 2 on export_to_ibl --- src/spikeinterface/exporters/__init__.py | 1 + src/spikeinterface/exporters/to_ibl.py | 294 +++++++++++++++++++++++ 2 files changed, 295 insertions(+) create mode 100644 src/spikeinterface/exporters/to_ibl.py diff --git a/src/spikeinterface/exporters/__init__.py b/src/spikeinterface/exporters/__init__.py index 50fcc304d1..6f872b9c3e 100644 --- a/src/spikeinterface/exporters/__init__.py +++ b/src/spikeinterface/exporters/__init__.py @@ -1,2 +1,3 @@ from .to_phy import export_to_phy from .report import export_report +from .to_ibl import export_to_ibl \ No newline at end of file diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py new file mode 100644 index 0000000000..649fbc015b --- /dev/null +++ b/src/spikeinterface/exporters/to_ibl.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import os +import shutil +import warnings +from pathlib import Path +from typing import Optional + +import numpy as np +import numpy.typing as npt +from tqdm.auto import tqdm + +from spikeinterface.core import ChannelSparsity, SortingAnalyzer +from spikeinterface.core.job_tools import divide_segment_into_chunks +from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.exporters import ( + export_to_phy, +) + + +def export_to_ibl( + analyzer: SortingAnalyzer, + output_folder: str | Path, + lfp_recording=None, + rms_win_length_sec=3, + welch_win_length_samples=2**14, + total_secs_spec_dens=100, + only_ibl_specific_steps=False, + sparsity: Optional[ChannelSparsity] = None, + remove_if_exists: bool = False, + verbose: bool = True, + **job_kwargs, +): + """ + Exports a sorting analyzer to the IBL gui format (similar to the Phy format with some extras). 
+ + Parameters + ---------- + analyzer: SortingAnalyzer + The sorting analyzer object to use for spike information. + Should also contain the pre-processed recording to use for AP-band data. + output_folder: str | Path + The output folder for the exports. + lfp_recording: Any SI Recording object, default None + The pre-processed recording to use for LFP data. If None, the LFP data is not exported. + rms_win_length_sec: float, default: 3 + The window length in seconds for the RMS calculation (on the LFP data). + welch_win_length_samples: int, default: 2^14 + The window length in samples for the Welch spectral density computation (on the LFP data). + total_secs_spec_dens: int, default: 100 + The total number of seconds to use for the spectral density calculation. + only_ibl_specific_steps: bool, default: False + If True, only the IBL specific steps are run (i.e. skips calling `export_to_phy`) + sparsity: ChannelSparsity or None, default: None + The sparsity object (currently only respected for phy part of the export) + remove_if_exists: bool, default: False + If True and "output_folder" exists, it is removed and overwritten + verbose: bool, default: True + If True, output is verbose + + """ + + try: + from scipy.signal import welch + except ImportError as e: + raise ImportError( + "Please install scipy to use the export_to_ibl function." + ) from e + + # Output folder checks + if isinstance(output_folder, str): + output_folder = Path(output_folder) + output_folder = Path(output_folder).absolute() + if output_folder.is_dir(): + if remove_if_exists: + shutil.rmtree(output_folder) + else: + raise FileExistsError(f"{output_folder} already exists") + else: + pass + # don't make the output dir yet, b/c export_to_phy will do that for us. 
+ + if verbose: + print("Exporting recording to IBL format...") + + # Compute any missing extensions + available_extension_names = analyzer.get_saved_extension_names() + required_exts = [ + "templates", + "template_similarity", + "spike_locations", + "noise_levels", + "quality_metrics", + ] + required_qms = ["amplitude_median", "isi_violations_ratio", "amplitude_cutoff"] + for ext in required_exts: + if ext not in available_extension_names: + if ext == "quality_metrics": + kwargs = {"skip_pc_metrics": False} + else: + kwargs = {} + analyzer.compute(ext, verbose=verbose, **kwargs) + elif ext == "quality_metrics": + qm = analyzer.get_extension("quality_metrics").get_data() + for rqm in required_qms: + if rqm not in qm: + analyzer.compute( + "quality_metrics", + metric_names=[rqm], + verbose=verbose, + ) + + # # Start by just exporting to phy + if not only_ibl_specific_steps: + if verbose: + print("Doing phy-like export...") + export_to_phy( + analyzer, + output_folder, + compute_amplitudes=True, + compute_pc_features=False, + sparsity=sparsity, + copy_binary=False, + template_mode="median", + verbose=verbose, + use_relative_path=False, + **job_kwargs, + ) + + # Make sure output dir exists, in case user skips export_to_phy + if not output_folder.is_dir(): + os.makedirs(output_folder) + + if verbose: + print("Running IBL-specific steps...") + + # Now we need to add the extra IBL specific files + # See here for docs on the format: https://github.com/int-brain-lab/iblapps/wiki/3.-Overview-of-datasets#input-histology-data + + # Subset channels in case some were excluded from spike sorting + (channel_inds,) = np.isin( + analyzer.recording.channel_ids, analyzer.channel_ids + ).nonzero() + + # TODO: put this into a chunk extractor + def _get_rms(rec): + chunk_nframes = int(rms_win_length_sec * rec.sampling_frequency) + chunks = divide_segment_into_chunks(rec.get_num_samples(), chunk_nframes) + chunk_rms = np.zeros((len(chunks), rec.get_num_channels())) + chunk_start_times = 
np.zeros((len(chunks),)) + for iChunk, (start_frame, stop_frame) in enumerate(tqdm(chunks)): + traces = rec.get_traces(start_frame=start_frame, end_frame=stop_frame) + chunk_rms[iChunk, :] = np.sqrt(np.mean(traces**2, axis=0)) + chunk_start_times[iChunk] = start_frame / rec.sampling_frequency + chunk_rms = chunk_rms[:, channel_inds] + chunk_rms = chunk_rms.astype(np.float32) + chunk_start_times = chunk_start_times.astype(np.float32) + return chunk_rms, chunk_start_times + + # Get RMS for the AP data. We will use a window of length rms_win_length_sec seconds slid over the entire recording. + ap_rec = analyzer.recording + if ap_rec.get_num_segments() != 1: + warnings.warn( + "Found ap recording with more than one segment, only using initial segment." + ) + ap_rec = ap_rec[0] + chunk_rms, chunk_start_times = _get_rms(ap_rec) + np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.rms.npy"), chunk_rms) + np.save( + os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.timestamps.npy"), + chunk_start_times, + ) + + if lfp_recording is not None: + # Get RMS for the LFP data. + if lfp_recording.get_num_segments() != 1: + warnings.warn( + "Found lfp recording with more than one segment, only using initial segment." 
+ ) + lfp_recording = lfp_recording[0] + chunk_rms, chunk_start_times = _get_rms(lfp_recording) + np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsLF.rms.npy"), chunk_rms) + np.save( + os.path.join(output_folder, "_iblqc_ephysTimeRmsLF.timestamps.npy"), + chunk_start_times, + ) + + # Get spectral density on a snippet of LFP data + end_frame = int(total_secs_spec_dens * lfp_recording.sampling_frequency) + traces = lfp_recording.get_traces( + start_frame=0, end_frame=end_frame + ) # time x channels + spec_density = np.zeros((welch_win_length_samples // 2 + 1, traces.shape[1])) + for iCh in range(traces.shape[1]): + f, Pxx = welch( + traces[:, iCh], + fs=lfp_recording.sampling_frequency, + nperseg=welch_win_length_samples, + ) + spec_density[:, iCh] = Pxx + spec_density = spec_density[ + :, channel_inds + ] # only keep channels that were used for spike sorting + spec_density = spec_density.astype(np.float32) + f = f.astype(np.float32) + assert spec_density.shape[0] == len(f) + np.save( + os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.power.npy"), + spec_density, + ) + np.save( + os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), f + ) + + ### Save spike info ### + + spike_locations = analyzer.load_extension("spike_locations").get_data() + spike_depths = spike_locations["y"] + + # convert clusters and squeeze + clusters = np.load(output_folder / "spike_clusters.npy") + np.save(output_folder / "spike_clusters.npy", np.squeeze(clusters.astype("uint32"))) + + # convert times and squeeze + times = np.load(output_folder / "spike_times.npy") + np.save( + output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64") + ) + + # convert amplitudes and squeeze + amps = np.load(output_folder / "amplitudes.npy") + np.save(output_folder / "amplitudes.npy", np.squeeze(-amps / 1e6).astype("float64")) + + # save depths and channel inds + np.save(output_folder / "spike_depths.npy", spike_depths) + np.save( + output_folder / 
"channel_inds.npy", np.arange(len(channel_inds), dtype="int") + ) + + # # save templates + cluster_channels = [] + cluster_peakToTrough = [] + cluster_waveforms = [] + templates = analyzer.get_extension("templates").get_data() + extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index") + + for unit_idx, unit_id in enumerate(analyzer.unit_ids): + waveform = templates[unit_idx, :, :] + extremum_channel_index = extremum_channel_indices[unit_id] + peak_waveform = waveform[:, extremum_channel_index] + peakToTrough = ( + np.argmax(peak_waveform) - np.argmin(peak_waveform) + ) / analyzer.sampling_frequency + # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units + cluster_channels.append( + extremum_channel_index + ) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870 + cluster_peakToTrough.append(peakToTrough) + cluster_waveforms.append(waveform) + + np.save(output_folder / "cluster_peakToTrough.npy", np.array(cluster_peakToTrough)) + np.save(output_folder / "cluster_waveforms.npy", np.stack(cluster_waveforms)) + np.save(output_folder / "cluster_channels.npy", np.array(cluster_channels)) + + # rename files from this func and the phy export func + _FILE_RENAMES = [ # file_in, file_out + ("channel_positions.npy", "channels.localCoordinates.npy"), + ("channel_inds.npy", "channels.rawInd.npy"), + ("cluster_peakToTrough.npy", "clusters.peakToTrough.npy"), + ("cluster_channels.npy", "clusters.channels.npy"), + ("cluster_waveforms.npy", "clusters.waveforms.npy"), + ("spike_clusters.npy", "spikes.clusters.npy"), + ("amplitudes.npy", "spikes.amps.npy"), + ("spike_depths.npy", "spikes.depths.npy"), + ("spike_times.npy", "spikes.times.npy"), + ] + + for names in _FILE_RENAMES: + old_name = output_folder / names[0] + new_name = output_folder / names[1] + os.rename(old_name, new_name) + + # save quality metrics + qm = analyzer.load_extension("quality_metrics") + 
qm_data = qm.get_data() + qm_data.index.name = "cluster_id" + qm_data["cluster_id.1"] = qm_data.index.values + good_ibl = ( # rough estimate of ibl standards + (qm_data["amplitude_median"] > 50) + & (qm_data["isi_violations_ratio"] < 0.2) + & (qm_data["amplitude_cutoff"] < 0.1) + ) + qm_data["label"] = good_ibl.astype("int") + qm_data.to_csv(output_folder / "clusters.metrics.csv") From b4f1bf4600b28624d61f06dc2b6882db3f21bee5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 14:01:10 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/__init__.py | 2 +- src/spikeinterface/exporters/to_ibl.py | 42 +++++++----------------- 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/exporters/__init__.py b/src/spikeinterface/exporters/__init__.py index 6f872b9c3e..1bec3e5cab 100644 --- a/src/spikeinterface/exporters/__init__.py +++ b/src/spikeinterface/exporters/__init__.py @@ -1,3 +1,3 @@ from .to_phy import export_to_phy from .report import export_report -from .to_ibl import export_to_ibl \ No newline at end of file +from .to_ibl import export_to_ibl diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 649fbc015b..4b4433a7ee 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -37,7 +37,7 @@ def export_to_ibl( Parameters ---------- analyzer: SortingAnalyzer - The sorting analyzer object to use for spike information. + The sorting analyzer object to use for spike information. Should also contain the pre-processed recording to use for AP-band data. output_folder: str | Path The output folder for the exports. 
@@ -63,9 +63,7 @@ def export_to_ibl( try: from scipy.signal import welch except ImportError as e: - raise ImportError( - "Please install scipy to use the export_to_ibl function." - ) from e + raise ImportError("Please install scipy to use the export_to_ibl function.") from e # Output folder checks if isinstance(output_folder, str): @@ -138,9 +136,7 @@ def export_to_ibl( # See here for docs on the format: https://github.com/int-brain-lab/iblapps/wiki/3.-Overview-of-datasets#input-histology-data # Subset channels in case some were excluded from spike sorting - (channel_inds,) = np.isin( - analyzer.recording.channel_ids, analyzer.channel_ids - ).nonzero() + (channel_inds,) = np.isin(analyzer.recording.channel_ids, analyzer.channel_ids).nonzero() # TODO: put this into a chunk extractor def _get_rms(rec): @@ -160,9 +156,7 @@ def _get_rms(rec): # Get RMS for the AP data. We will use a window of length rms_win_length_sec seconds slid over the entire recording. ap_rec = analyzer.recording if ap_rec.get_num_segments() != 1: - warnings.warn( - "Found ap recording with more than one segment, only using initial segment." - ) + warnings.warn("Found ap recording with more than one segment, only using initial segment.") ap_rec = ap_rec[0] chunk_rms, chunk_start_times = _get_rms(ap_rec) np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.rms.npy"), chunk_rms) @@ -174,9 +168,7 @@ def _get_rms(rec): if lfp_recording is not None: # Get RMS for the LFP data. if lfp_recording.get_num_segments() != 1: - warnings.warn( - "Found lfp recording with more than one segment, only using initial segment." 
- ) + warnings.warn("Found lfp recording with more than one segment, only using initial segment.") lfp_recording = lfp_recording[0] chunk_rms, chunk_start_times = _get_rms(lfp_recording) np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsLF.rms.npy"), chunk_rms) @@ -187,9 +179,7 @@ def _get_rms(rec): # Get spectral density on a snippet of LFP data end_frame = int(total_secs_spec_dens * lfp_recording.sampling_frequency) - traces = lfp_recording.get_traces( - start_frame=0, end_frame=end_frame - ) # time x channels + traces = lfp_recording.get_traces(start_frame=0, end_frame=end_frame) # time x channels spec_density = np.zeros((welch_win_length_samples // 2 + 1, traces.shape[1])) for iCh in range(traces.shape[1]): f, Pxx = welch( @@ -198,9 +188,7 @@ def _get_rms(rec): nperseg=welch_win_length_samples, ) spec_density[:, iCh] = Pxx - spec_density = spec_density[ - :, channel_inds - ] # only keep channels that were used for spike sorting + spec_density = spec_density[:, channel_inds] # only keep channels that were used for spike sorting spec_density = spec_density.astype(np.float32) f = f.astype(np.float32) assert spec_density.shape[0] == len(f) @@ -208,9 +196,7 @@ def _get_rms(rec): os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.power.npy"), spec_density, ) - np.save( - os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), f - ) + np.save(os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), f) ### Save spike info ### @@ -223,9 +209,7 @@ def _get_rms(rec): # convert times and squeeze times = np.load(output_folder / "spike_times.npy") - np.save( - output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64") - ) + np.save(output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64")) # convert amplitudes and squeeze amps = np.load(output_folder / "amplitudes.npy") @@ -233,9 +217,7 @@ def _get_rms(rec): # save depths and channel inds np.save(output_folder / "spike_depths.npy", 
spike_depths) - np.save( - output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int") - ) + np.save(output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int")) # # save templates cluster_channels = [] @@ -248,9 +230,7 @@ def _get_rms(rec): waveform = templates[unit_idx, :, :] extremum_channel_index = extremum_channel_indices[unit_id] peak_waveform = waveform[:, extremum_channel_index] - peakToTrough = ( - np.argmax(peak_waveform) - np.argmin(peak_waveform) - ) / analyzer.sampling_frequency + peakToTrough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units cluster_channels.append( extremum_channel_index From 74edab9c9a6166922432c6ecfcc8561e1645037b Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Fri, 11 Apr 2025 09:28:21 -0400 Subject: [PATCH 03/13] minor changes from PR review --- src/spikeinterface/exporters/to_ibl.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 4b4433a7ee..f49e9118b1 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -1,5 +1,6 @@ from __future__ import annotations +from importlib.util import find_spec import os import shutil import warnings @@ -7,7 +8,6 @@ from typing import Optional import numpy as np -import numpy.typing as npt from tqdm.auto import tqdm from spikeinterface.core import ChannelSparsity, SortingAnalyzer @@ -60,10 +60,10 @@ def export_to_ibl( """ - try: + if find_spec("scipy") is None: + raise ImportError("Please install scipy to use the export_to_ibl function.") + else: from scipy.signal import welch - except ImportError as e: - raise ImportError("Please install scipy to use the export_to_ibl function.") from e # Output folder checks if isinstance(output_folder, str): 
@@ -209,7 +209,7 @@ def _get_rms(rec): # convert times and squeeze times = np.load(output_folder / "spike_times.npy") - np.save(output_folder / "spike_times.npy", np.squeeze(times / 30000.0).astype("float64")) + np.save(output_folder / "spike_times.npy", np.squeeze(times / analyzer.sampling_frequency).astype("float64")) # convert amplitudes and squeeze amps = np.load(output_folder / "amplitudes.npy") @@ -221,24 +221,24 @@ def _get_rms(rec): # # save templates cluster_channels = [] - cluster_peakToTrough = [] + cluster_peak_to_trough = [] cluster_waveforms = [] templates = analyzer.get_extension("templates").get_data() extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index") - for unit_idx, unit_id in enumerate(analyzer.unit_ids): - waveform = templates[unit_idx, :, :] + for unit_index, unit_id in enumerate(analyzer.unit_ids): + waveform = templates[unit_index, :, :] extremum_channel_index = extremum_channel_indices[unit_id] peak_waveform = waveform[:, extremum_channel_index] - peakToTrough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency + peak_to_trough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? 
fails for odd nums of units cluster_channels.append( extremum_channel_index ) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870 - cluster_peakToTrough.append(peakToTrough) + cluster_peak_to_trough.append(peak_to_trough) cluster_waveforms.append(waveform) - np.save(output_folder / "cluster_peakToTrough.npy", np.array(cluster_peakToTrough)) + np.save(output_folder / "cluster_peak_to_trough.npy", np.array(cluster_peak_to_trough)) np.save(output_folder / "cluster_waveforms.npy", np.stack(cluster_waveforms)) np.save(output_folder / "cluster_channels.npy", np.array(cluster_channels)) @@ -246,7 +246,7 @@ def _get_rms(rec): _FILE_RENAMES = [ # file_in, file_out ("channel_positions.npy", "channels.localCoordinates.npy"), ("channel_inds.npy", "channels.rawInd.npy"), - ("cluster_peakToTrough.npy", "clusters.peakToTrough.npy"), + ("cluster_peak_to_trough.npy", "clusters.peakToTrough.npy"), ("cluster_channels.npy", "clusters.channels.npy"), ("cluster_waveforms.npy", "clusters.waveforms.npy"), ("spike_clusters.npy", "spikes.clusters.npy"), From eaf847232a684b4177153238296f7aa51b8754f6 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Fri, 11 Apr 2025 10:36:09 -0400 Subject: [PATCH 04/13] test for ibl exporter --- .../exporters/tests/test_export_to_ibl.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/spikeinterface/exporters/tests/test_export_to_ibl.py diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py b/src/spikeinterface/exporters/tests/test_export_to_ibl.py new file mode 100644 index 0000000000..406a467efd --- /dev/null +++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py @@ -0,0 +1,28 @@ +import shutil + +import numpy as np + +from spikeinterface.exporters import export_to_ibl + +from spikeinterface.exporters.tests.common import ( + make_sorting_analyzer, + sorting_analyzer_sparse_for_export, +) + +def test_export_to_ibl(sorting_analyzer_sparse_for_export, 
create_cache_folder): + cache_folder = create_cache_folder + output_folder = cache_folder / "ibl_output" + for f in (output_folder,): + if f.is_dir(): + shutil.rmtree(f) + + export_to_ibl( + sorting_analyzer_sparse_for_export, + output_folder, + lfp_recording=sorting_analyzer_sparse_for_export.recording, + ) + + +if __name__ == "__main__": + sorting_analyzer = make_sorting_analyzer(sparse=True) + test_export_to_ibl(sorting_analyzer) \ No newline at end of file From de22a6c761d9c168245894582b96994df6e5973d Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Fri, 11 Apr 2025 10:36:43 -0400 Subject: [PATCH 05/13] test fixes --- src/spikeinterface/exporters/to_ibl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index f49e9118b1..0a9bde5772 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -82,7 +82,10 @@ def export_to_ibl( print("Exporting recording to IBL format...") # Compute any missing extensions - available_extension_names = analyzer.get_saved_extension_names() + if analyzer.format == "memory": + available_extension_names = list(analyzer.extensions.keys()) + else: + available_extension_names = analyzer.get_saved_extension_names() required_exts = [ "templates", "template_similarity", @@ -90,7 +93,6 @@ def export_to_ibl( "noise_levels", "quality_metrics", ] - required_qms = ["amplitude_median", "isi_violations_ratio", "amplitude_cutoff"] for ext in required_exts: if ext not in available_extension_names: if ext == "quality_metrics": @@ -200,7 +202,7 @@ def _get_rms(rec): ### Save spike info ### - spike_locations = analyzer.load_extension("spike_locations").get_data() + spike_locations = analyzer.get_extension("spike_locations").get_data() spike_depths = spike_locations["y"] # convert clusters and squeeze @@ -261,7 +263,7 @@ def _get_rms(rec): os.rename(old_name, new_name) # save quality metrics - qm = 
analyzer.load_extension("quality_metrics") + qm = analyzer.get_extension("quality_metrics") qm_data = qm.get_data() qm_data.index.name = "cluster_id" qm_data["cluster_id.1"] = qm_data.index.values From 6eb068d5192ca3bebdf62d8a3ac4ba234dff6c61 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Fri, 11 Apr 2025 10:37:00 -0400 Subject: [PATCH 06/13] fix amp sign, loosen ibl standards --- src/spikeinterface/exporters/to_ibl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 0a9bde5772..30a2606e29 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -267,10 +267,12 @@ def _get_rms(rec): qm_data = qm.get_data() qm_data.index.name = "cluster_id" qm_data["cluster_id.1"] = qm_data.index.values - good_ibl = ( # rough estimate of ibl standards - (qm_data["amplitude_median"] > 50) - & (qm_data["isi_violations_ratio"] < 0.2) - & (qm_data["amplitude_cutoff"] < 0.1) + amplitude_sign_coef = -1 if analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "neg" else 1 + + good_ibl = ( # rough, slightly looser estimate of ibl standards + ((amplitude_sign_coef * qm_data["amplitude_median"]) > 40) + & (qm_data["isi_violations_ratio"] < 0.5) + & (qm_data["amplitude_cutoff"] < 0.2) ) qm_data["label"] = good_ibl.astype("int") qm_data.to_csv(output_folder / "clusters.metrics.csv") From 789236cc1fc1f7a0b4c58d31cd93814d2fb23781 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Fri, 11 Apr 2025 10:37:28 -0400 Subject: [PATCH 07/13] simplify qm checks --- src/spikeinterface/exporters/to_ibl.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 30a2606e29..f20468d2a7 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -96,19 +96,21 @@ def export_to_ibl( for 
ext in required_exts: if ext not in available_extension_names: if ext == "quality_metrics": - kwargs = {"skip_pc_metrics": False} + kwargs = {"skip_pc_metrics": True} else: kwargs = {} analyzer.compute(ext, verbose=verbose, **kwargs) - elif ext == "quality_metrics": - qm = analyzer.get_extension("quality_metrics").get_data() - for rqm in required_qms: - if rqm not in qm: - analyzer.compute( - "quality_metrics", - metric_names=[rqm], - verbose=verbose, - ) + + # Check in case user pre-calculated a small set of qm's that aren't enough for IBL + required_qms = ["amplitude_median", "isi_violation", "amplitude_cutoff"] + qm = analyzer.get_extension("quality_metrics").get_data() + for rqm in required_qms: + if rqm not in qm: + analyzer.compute( + "quality_metrics", + metric_names=[rqm], + verbose=verbose, + ) # # Start by just exporting to phy if not only_ibl_specific_steps: From 4913ada33ce0d0335668129ea70114047d797f54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Apr 2025 14:38:39 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/tests/test_export_to_ibl.py | 3 ++- src/spikeinterface/exporters/to_ibl.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py b/src/spikeinterface/exporters/tests/test_export_to_ibl.py index 406a467efd..c5905c12de 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_ibl.py +++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py @@ -9,6 +9,7 @@ sorting_analyzer_sparse_for_export, ) + def test_export_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): cache_folder = create_cache_folder output_folder = cache_folder / "ibl_output" @@ -25,4 +26,4 @@ def test_export_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): if __name__ == "__main__": 
sorting_analyzer = make_sorting_analyzer(sparse=True) - test_export_to_ibl(sorting_analyzer) \ No newline at end of file + test_export_to_ibl(sorting_analyzer) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index f20468d2a7..1c222a709b 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -100,8 +100,8 @@ def export_to_ibl( else: kwargs = {} analyzer.compute(ext, verbose=verbose, **kwargs) - - # Check in case user pre-calculated a small set of qm's that aren't enough for IBL + + # Check in case user pre-calculated a small set of qm's that aren't enough for IBL required_qms = ["amplitude_median", "isi_violation", "amplitude_cutoff"] qm = analyzer.get_extension("quality_metrics").get_data() for rqm in required_qms: @@ -270,7 +270,7 @@ def _get_rms(rec): qm_data.index.name = "cluster_id" qm_data["cluster_id.1"] = qm_data.index.values amplitude_sign_coef = -1 if analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "neg" else 1 - + good_ibl = ( # rough, slightly looser estimate of ibl standards ((amplitude_sign_coef * qm_data["amplitude_median"]) > 40) & (qm_data["isi_violations_ratio"] < 0.5) From b280a73c75837ee263ac27ed350d2ee746596706 Mon Sep 17 00:00:00 2001 From: jonahpearl Date: Sat, 12 Apr 2025 09:34:47 -0400 Subject: [PATCH 09/13] round 2 of code review --- src/spikeinterface/exporters/to_ibl.py | 43 ++++++++++++-------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 1c222a709b..9d60eadb4c 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -61,7 +61,7 @@ def export_to_ibl( """ if find_spec("scipy") is None: - raise ImportError("Please install scipy to use the export_to_ibl function.") + raise ImportError("Please install scipy to use the `export_to_ibl` function.") else: from scipy.signal import welch @@ 
-104,13 +104,8 @@ def export_to_ibl( # Check in case user pre-calculated a small set of qm's that aren't enough for IBL required_qms = ["amplitude_median", "isi_violation", "amplitude_cutoff"] qm = analyzer.get_extension("quality_metrics").get_data() - for rqm in required_qms: - if rqm not in qm: - analyzer.compute( - "quality_metrics", - metric_names=[rqm], - verbose=verbose, - ) + qms_to_compute = [metric for metric in required_qms if metric not in qm] + analyzer.compute("quality_metrics", metric_names=qms_to_compute, verbose=verbose) # # Start by just exporting to phy if not only_ibl_specific_steps: @@ -157,17 +152,20 @@ def _get_rms(rec): chunk_start_times = chunk_start_times.astype(np.float32) return chunk_rms, chunk_start_times - # Get RMS for the AP data. We will use a window of length rms_win_length_sec seconds slid over the entire recording. - ap_rec = analyzer.recording - if ap_rec.get_num_segments() != 1: - warnings.warn("Found ap recording with more than one segment, only using initial segment.") - ap_rec = ap_rec[0] - chunk_rms, chunk_start_times = _get_rms(ap_rec) - np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.rms.npy"), chunk_rms) - np.save( - os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.timestamps.npy"), - chunk_start_times, - ) + if analyzer.has_recording(): + # Get RMS for the AP data. We will use a window of length rms_win_length_sec seconds slid over the entire recording. 
+ ap_rec = analyzer.recording + if ap_rec.get_num_segments() != 1: + warnings.warn("Found ap recording with more than one segment, only using initial segment.") + ap_rec = ap_rec[0] + chunk_rms, chunk_start_times = _get_rms(ap_rec) + np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.rms.npy"), chunk_rms) + np.save( + os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.timestamps.npy"), + chunk_start_times, + ) + elif verbose: + print("No recording data found in the SortingAnalyzer, skipping AP RMS calculation.") if lfp_recording is not None: # Get RMS for the LFP data. @@ -186,7 +184,7 @@ def _get_rms(rec): traces = lfp_recording.get_traces(start_frame=0, end_frame=end_frame) # time x channels spec_density = np.zeros((welch_win_length_samples // 2 + 1, traces.shape[1])) for iCh in range(traces.shape[1]): - f, Pxx = welch( + freqs, Pxx = welch( traces[:, iCh], fs=lfp_recording.sampling_frequency, nperseg=welch_win_length_samples, @@ -194,13 +192,12 @@ def _get_rms(rec): spec_density[:, iCh] = Pxx spec_density = spec_density[:, channel_inds] # only keep channels that were used for spike sorting spec_density = spec_density.astype(np.float32) - f = f.astype(np.float32) - assert spec_density.shape[0] == len(f) + freqs = freqs.astype(np.float32) np.save( os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.power.npy"), spec_density, ) - np.save(os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), f) + np.save(os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), freqs) ### Save spike info ### From 74a813db5ed9edde0cfdc6c65af7d1f0961cbf3f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 12 Apr 2025 13:35:15 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/to_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 9d60eadb4c..7596230795 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -104,7 +104,7 @@ def export_to_ibl( # Check in case user pre-calculated a small set of qm's that aren't enough for IBL required_qms = ["amplitude_median", "isi_violation", "amplitude_cutoff"] qm = analyzer.get_extension("quality_metrics").get_data() - qms_to_compute = [metric for metric in required_qms if metric not in qm] + qms_to_compute = [metric for metric in required_qms if metric not in qm] analyzer.compute("quality_metrics", metric_names=qms_to_compute, verbose=verbose) # # Start by just exporting to phy From c0941958790bb018697180ca1c62547caa1b7ef5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 23 Apr 2025 16:57:36 +0200 Subject: [PATCH 11/13] parallel rms, remove export to phy, clean up --- src/spikeinterface/exporters/tests/common.py | 5 +- .../exporters/tests/test_export_to_ibl.py | 108 ++++- src/spikeinterface/exporters/to_ibl.py | 420 ++++++++++-------- 3 files changed, 328 insertions(+), 205 deletions(-) diff --git a/src/spikeinterface/exporters/tests/common.py b/src/spikeinterface/exporters/tests/common.py index 78a9c82860..2b191bc1e3 100644 --- a/src/spikeinterface/exporters/tests/common.py +++ b/src/spikeinterface/exporters/tests/common.py @@ -45,7 +45,10 @@ def make_sorting_analyzer(sparse=True, with_group=False): sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("principal_components") sorting_analyzer.compute("template_similarity") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute( + "quality_metrics", metric_names=["snr", "amplitude_median", "isi_violation", "amplitude_cutoff"] + ) + sorting_analyzer.compute(["spike_amplitudes", "spike_locations"]) return sorting_analyzer diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py 
b/src/spikeinterface/exporters/tests/test_export_to_ibl.py index c5905c12de..d44b33250b 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_ibl.py +++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py @@ -1,7 +1,6 @@ -import shutil - -import numpy as np +import pytest +from spikeinterface.preprocessing import bandpass_filter, decimate from spikeinterface.exporters import export_to_ibl from spikeinterface.exporters.tests.common import ( @@ -9,21 +8,108 @@ sorting_analyzer_sparse_for_export, ) +required_output_files = [ + "spikes.times.npy", + "spikes.clusters.npy", + "spikes.depths.npy", + "spikes.amps.npy", + "clusters.waveforms.npy", + "clusters.peakToTrough.npy", + "clusters.channels.npy", + "clusters.metrics.csv", + "channels.localCoordinates.npy", + "channels.rawInd.npy", +] +ap_output_files = ["_iblqc_ephysTimeRmsAP.rms.npy", "_iblqc_ephysTimeRmsAP.timestamps.npy"] +lfp_output_files = [ + "_iblqc_ephysTimeRmsLF.rms.npy", + "_iblqc_ephysTimeRmsLF.timestamps.npy", + "_iblqc_ephysSpectralDensityLF.power.npy", + "_iblqc_ephysSpectralDensityLF.freqs.npy", +] + +good_units_query = "amplitude_median < -30" + -def test_export_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): +def test_export_ap_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): cache_folder = create_cache_folder - output_folder = cache_folder / "ibl_output" - for f in (output_folder,): - if f.is_dir(): - shutil.rmtree(f) + output_folder = cache_folder / "ibl_ap_output" + sorting_analyzer = sorting_analyzer_sparse_for_export + # AP, but no LFP export_to_ibl( - sorting_analyzer_sparse_for_export, + sorting_analyzer, output_folder, - lfp_recording=sorting_analyzer_sparse_for_export.recording, + # good_units_query=good_units_query, + verbose=True, + n_jobs=-1, ) + for f in required_output_files: + assert (output_folder / f).exists(), f"Missing file: {f}" + for f in ap_output_files: + assert (output_folder / f).exists(), f"Missing file: {f}" + for f in 
lfp_output_files: + assert not (output_folder / f).exists(), f"Unexpected file: {f}" + + +def test_export_recordingless_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder + output_folder = cache_folder / "ibl_recordingless_output" + + sorting_analyzer = sorting_analyzer_sparse_for_export + recording = sorting_analyzer.recording + sorting_analyzer._recording = None + + # AP, but no LFP + export_to_ibl(sorting_analyzer_sparse_for_export, output_folder, good_units_query=good_units_query, n_jobs=-1) + for f in required_output_files: + assert (output_folder / f).exists(), f"Missing file: {f}" + for f in ap_output_files: + assert not (output_folder / f).exists(), f"Missing file: {f}" + for f in lfp_output_files: + assert not (output_folder / f).exists(), f"Unexpected file: {f}" + + sorting_analyzer._recording = recording + + +def test_export_lfp_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder + output_folder = cache_folder / "ibl_lfp_output" + + sorting_analyzer = sorting_analyzer_sparse_for_export + recording = sorting_analyzer.recording + recording_lfp = bandpass_filter(recording, freq_min=0.5, freq_max=300) + recording_lfp = decimate(recording_lfp, 10) + # LFP, but no AP + export_to_ibl( + sorting_analyzer, output_folder, lfp_recording=recording_lfp, good_units_query=good_units_query, n_jobs=-1 + ) + for f in required_output_files: + assert (output_folder / f).exists(), f"Missing file: {f}" + for f in ap_output_files: + assert (output_folder / f).exists(), f"Unexpected file: {f}" + for f in lfp_output_files: + assert (output_folder / f).exists(), f"Missing file: {f}" + + +def test_missing_info(sorting_analyzer_sparse_for_export, create_cache_folder): + cache_folder = create_cache_folder + output_folder = cache_folder / "ibl_missing_info_output" + + sorting_analyzer = sorting_analyzer_sparse_for_export + + # missing metrics + good_units_query = "rp_violations < 0.2" 
+ + with pytest.raises(ValueError, match="Missing required quality metrics"): + export_to_ibl(sorting_analyzer, output_folder, good_units_query=good_units_query, n_jobs=-1) + + sorting_analyzer.delete_extension("spike_amplitudes") + + with pytest.raises(ValueError, match="Missing required extension"): + export_to_ibl(sorting_analyzer, output_folder, n_jobs=-1) if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_export_to_ibl(sorting_analyzer) + test_export_ap_to_ibl(sorting_analyzer) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 7596230795..456d0974d1 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -1,38 +1,34 @@ from __future__ import annotations from importlib.util import find_spec -import os +import re import shutil import warnings from pathlib import Path -from typing import Optional import numpy as np -from tqdm.auto import tqdm -from spikeinterface.core import ChannelSparsity, SortingAnalyzer -from spikeinterface.core.job_tools import divide_segment_into_chunks +from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks +from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc from spikeinterface.core.template_tools import get_template_extremum_channel -from spikeinterface.exporters import ( - export_to_phy, -) +from spikeinterface.exporters import export_to_phy def export_to_ibl( - analyzer: SortingAnalyzer, + sorting_analyzer: SortingAnalyzer, output_folder: str | Path, - lfp_recording=None, - rms_win_length_sec=3, + lfp_recording: BaseRecording | None = None, + rms_win_length_s=3, welch_win_length_samples=2**14, - total_secs_spec_dens=100, - only_ibl_specific_steps=False, - sparsity: Optional[ChannelSparsity] = None, + psd_chunk_duration_s=1, + psd_num_chunks=100, + good_units_query: str | None = "amplitude_median < -40 and 
isi_violations_ratio < 0.5 and amplitude_cutoff < 0.2", remove_if_exists: bool = False, verbose: bool = True, **job_kwargs, ): """ - Exports a sorting analyzer to the IBL gui format (similar to the Phy format with some extras). + Exports a sorting analyzer to the IBL GUI format (similar to the Phy format with some extras). Parameters ---------- @@ -41,18 +37,16 @@ def export_to_ibl( Should also contain the pre-processed recording to use for AP-band data. output_folder: str | Path The output folder for the exports. - lfp_recording: Any SI Recording object, default None + lfp_recording: BaseRecording | None, default: None The pre-processed recording to use for LFP data. If None, the LFP data is not exported. - rms_win_length_sec: float, default: 3 + rms_win_length_s: float, default: 3 The window length in seconds for the RMS calculation (on the LFP data). welch_win_length_samples: int, default: 2^14 The window length in samples for the Welch spectral density computation (on the LFP data). - total_secs_spec_dens: int, default: 100 - The total number of seconds to use for the spectral density calculation. - only_ibl_specific_steps: bool, default: False - If True, only the IBL specific steps are run (i.e. skips calling `export_to_phy`) - sparsity: ChannelSparsity or None, default: None - The sparsity object (currently only respected for phy part of the export) + psd_chunk_duration_s: float, default: 1 + The chunk duration in seconds for the spectral density calculation (on the LFP data). + psd_num_chunks: int, default: 100 + The number of chunks to use for the spectral density calculation (on the LFP data). 
remove_if_exists: bool, default: False If True and "output_folder" exists, it is removed and overwritten verbose: bool, default: True @@ -65,213 +59,253 @@ def export_to_ibl( else: from scipy.signal import welch + if sorting_analyzer.get_num_segments() != 1: + raise ValueError("The export to IBL format only supports a single segment.") + # Output folder checks - if isinstance(output_folder, str): - output_folder = Path(output_folder) output_folder = Path(output_folder).absolute() if output_folder.is_dir(): if remove_if_exists: shutil.rmtree(output_folder) else: raise FileExistsError(f"{output_folder} already exists") - else: - pass - # don't make the output dir yet, b/c export_to_phy will do that for us. if verbose: print("Exporting recording to IBL format...") # Compute any missing extensions - if analyzer.format == "memory": - available_extension_names = list(analyzer.extensions.keys()) - else: - available_extension_names = analyzer.get_saved_extension_names() - required_exts = [ + required_extensions = [ "templates", - "template_similarity", - "spike_locations", - "noise_levels", + "spike_amplitudes", "quality_metrics", ] - for ext in required_exts: - if ext not in available_extension_names: - if ext == "quality_metrics": - kwargs = {"skip_pc_metrics": True} - else: - kwargs = {} - analyzer.compute(ext, verbose=verbose, **kwargs) + for ext in required_extensions: + if not sorting_analyzer.has_extension(ext): + raise ValueError(f"Missing required extension: {ext}. 
Please compute it before exporting to IBL format.") # Check in case user pre-calculated a small set of qm's that aren't enough for IBL - required_qms = ["amplitude_median", "isi_violation", "amplitude_cutoff"] - qm = analyzer.get_extension("quality_metrics").get_data() - qms_to_compute = [metric for metric in required_qms if metric not in qm] - analyzer.compute("quality_metrics", metric_names=qms_to_compute, verbose=verbose) - - # # Start by just exporting to phy - if not only_ibl_specific_steps: - if verbose: - print("Doing phy-like export...") - export_to_phy( - analyzer, - output_folder, - compute_amplitudes=True, - compute_pc_features=False, - sparsity=sparsity, - copy_binary=False, - template_mode="median", - verbose=verbose, - use_relative_path=False, - **job_kwargs, - ) + if good_units_query is not None: + quality_metrics_in_query = re.split(">|<|>=|<=|==|and", good_units_query)[::2] + required_qms = [qm_name.strip() for qm_name in quality_metrics_in_query] + qm = sorting_analyzer.get_extension("quality_metrics").get_data() + missing_metrics = [] + for qm_name in required_qms: + if qm_name not in qm.columns: + missing_metrics.append(qm_name) + if len(missing_metrics) > 0: + raise ValueError( + f"Missing required quality metrics: {missing_metrics}. Please compute it before exporting to IBL format." 
+ ) # Make sure output dir exists, in case user skips export_to_phy if not output_folder.is_dir(): - os.makedirs(output_folder) + output_folder.mkdir(parents=True, exist_ok=True) - if verbose: - print("Running IBL-specific steps...") + ### Save spikes info ### + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) + + # spikes.clusters + np.save(output_folder / "spikes.clusters.npy", spikes["unit_index"].astype("int32")) + + # spike depths + if sorting_analyzer.has_extension("spike_locations"): + spike_locations = sorting_analyzer.get_extension("spike_locations").get_data() + spike_depths = spike_locations["y"] + else: + # we use the extremum channel depth for each spike + spike_depths = sorting_analyzer.get_channel_locations()[:, 1][spikes["channel_index"]] + np.save(output_folder / "spikes.depths.npy", spike_depths.astype("float32")) + + # spike times + spike_sample_indices = spikes["sample_index"] + if sorting_analyzer.has_recording() and sorting_analyzer.recording.has_time_vector(): + spike_times = sorting_analyzer.recording.get_times()[spike_sample_indices] + else: + spike_times = spike_sample_indices / sorting_analyzer.sampling_frequency + np.save(output_folder / "spikes.times.npy", spike_times.astype("float64")) + + # spike amps + amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + amps_positive_in_V = -amps * 1e-6 + np.save(output_folder / "spikes.amps.npy", amps_positive_in_V.astype("float32")) + + ### Save clusters info ### + + # templates + templates = sorting_analyzer.get_extension("templates").get_data() + np.save(output_folder / "clusters.waveforms.npy", templates) + + # cluster channels + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") + np.save(output_folder / 
"clusters.channels.npy", cluster_channels) + + # peak-to-trough durations + + # if template_metrics are already computed, use them to get the peak-to-trough durations + peak_to_trough_durations = None + if sorting_analyzer.has_extension("template_metrics"): + template_metrics = sorting_analyzer.get_extension("template_metrics").get_data() + if "peak_to_valley" in template_metrics.columns: + peak_to_trough_durations = template_metrics["peak_to_valley"].values + + # if not, we will compute them ourselves + if peak_to_trough_durations is None: + peak_to_trough_durations = [] + # get the channel index of the max amplitude for each cluster + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + template = templates[unit_index, :, :] + extremum_channel_index = extremum_channel_indices[unit_id] + peak_waveform = template[:, extremum_channel_index] + peak_to_trough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / sorting_analyzer.sampling_frequency + peak_to_trough_durations.append(peak_to_trough) + peak_to_trough_durations = np.array(peak_to_trough_durations) + np.save(output_folder / "clusters.peakToTrough.npy", peak_to_trough_durations) + + # quality metrics + qm = sorting_analyzer.get_extension("quality_metrics") + qm_data = qm.get_data() + qm_data.index.name = "cluster_id" + qm_data["cluster_id.1"] = qm_data.index.values + + if good_units_query is None: + qm_data["label"] = 1 + else: + good_units = qm_data.query(good_units_query) + good_units_indices = good_units.index.values + labels = np.zeros(len(qm_data), dtype="int32") + qm_data["label"] = labels + qm_data.loc[good_units_indices, "label"] = 1 + qm_data.to_csv(output_folder / "clusters.metrics.csv") + + ### Save channels info ### + + # channel positions + channel_positions = sorting_analyzer.get_channel_locations() + np.save(output_folder / "channels.localCoordinates.npy", channel_positions) + + # channel indices + np.save(output_folder / "channels.rawInd.npy", 
np.arange(sorting_analyzer.get_num_channels(), dtype="int32")) # Now we need to add the extra IBL specific files # See here for docs on the format: https://github.com/int-brain-lab/iblapps/wiki/3.-Overview-of-datasets#input-histology-data - - # Subset channels in case some were excluded from spike sorting - (channel_inds,) = np.isin(analyzer.recording.channel_ids, analyzer.channel_ids).nonzero() - - # TODO: put this into a chunk extractor - def _get_rms(rec): - chunk_nframes = int(rms_win_length_sec * rec.sampling_frequency) - chunks = divide_segment_into_chunks(rec.get_num_samples(), chunk_nframes) - chunk_rms = np.zeros((len(chunks), rec.get_num_channels())) - chunk_start_times = np.zeros((len(chunks),)) - for iChunk, (start_frame, stop_frame) in enumerate(tqdm(chunks)): - traces = rec.get_traces(start_frame=start_frame, end_frame=stop_frame) - chunk_rms[iChunk, :] = np.sqrt(np.mean(traces**2, axis=0)) - chunk_start_times[iChunk] = start_frame / rec.sampling_frequency - chunk_rms = chunk_rms[:, channel_inds] - chunk_rms = chunk_rms.astype(np.float32) - chunk_start_times = chunk_start_times.astype(np.float32) - return chunk_rms, chunk_start_times - - if analyzer.has_recording(): - # Get RMS for the AP data. We will use a window of length rms_win_length_sec seconds slid over the entire recording. - ap_rec = analyzer.recording - if ap_rec.get_num_segments() != 1: - warnings.warn("Found ap recording with more than one segment, only using initial segment.") - ap_rec = ap_rec[0] - chunk_rms, chunk_start_times = _get_rms(ap_rec) - np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.rms.npy"), chunk_rms) - np.save( - os.path.join(output_folder, "_iblqc_ephysTimeRmsAP.timestamps.npy"), - chunk_start_times, + if sorting_analyzer.has_recording(): + # Get RMS for the preprocessed (AP) data. We will use a window of length rms_win_length_s seconds slid over the entire recording. 
+ if verbose: + print("Computing AP RMS") + recording_ap = sorting_analyzer.recording + job_kwargs_ = job_kwargs.copy() + job_kwargs_["chunk_duration"] = f"{rms_win_length_s}s" + rms_preprocessed, rms_times = compute_rms( + recording_ap, + verbose=verbose, + **job_kwargs_, ) + np.save(output_folder / "_iblqc_ephysTimeRmsAP.rms.npy", rms_preprocessed) + np.save(output_folder / "_iblqc_ephysTimeRmsAP.timestamps.npy", rms_times) elif verbose: print("No recording data found in the SortingAnalyzer, skipping AP RMS calculation.") if lfp_recording is not None: - # Get RMS for the LFP data. - if lfp_recording.get_num_segments() != 1: - warnings.warn("Found lfp recording with more than one segment, only using initial segment.") - lfp_recording = lfp_recording[0] - chunk_rms, chunk_start_times = _get_rms(lfp_recording) - np.save(os.path.join(output_folder, "_iblqc_ephysTimeRmsLF.rms.npy"), chunk_rms) - np.save( - os.path.join(output_folder, "_iblqc_ephysTimeRmsLF.timestamps.npy"), - chunk_start_times, - ) + # Get RMS for the LFP data + if verbose: + print("Computing LFP RMS") + job_kwargs_ = job_kwargs.copy() + job_kwargs_["chunk_duration"] = f"{rms_win_length_s}s" + rms_lfp, rms_times = compute_rms(lfp_recording, verbose=verbose, **job_kwargs_) + np.save(output_folder / "_iblqc_ephysTimeRmsLF.rms.npy", rms_lfp) + np.save(output_folder / "_iblqc_ephysTimeRmsLF.timestamps.npy", rms_times) # Get spectral density on a snippet of LFP data - end_frame = int(total_secs_spec_dens * lfp_recording.sampling_frequency) - traces = lfp_recording.get_traces(start_frame=0, end_frame=end_frame) # time x channels - spec_density = np.zeros((welch_win_length_samples // 2 + 1, traces.shape[1])) - for iCh in range(traces.shape[1]): + if verbose: + print("Computing LFP PSD") + lfp_sample_data = get_random_data_chunks( + lfp_recording, + num_chunks_per_segment=psd_num_chunks, + chunk_duration=f"{psd_chunk_duration_s}s", + return_scaled=True, + concatenated=True, + ) + psd = 
np.zeros((welch_win_length_samples // 2 + 1, lfp_sample_data.shape[1]), dtype=np.float32) + for i_channel in range(lfp_sample_data.shape[1]): freqs, Pxx = welch( - traces[:, iCh], + lfp_sample_data[:, i_channel], fs=lfp_recording.sampling_frequency, nperseg=welch_win_length_samples, ) - spec_density[:, iCh] = Pxx - spec_density = spec_density[:, channel_inds] # only keep channels that were used for spike sorting - spec_density = spec_density.astype(np.float32) + psd[:, i_channel] = Pxx freqs = freqs.astype(np.float32) - np.save( - os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.power.npy"), - spec_density, - ) - np.save(os.path.join(output_folder, "_iblqc_ephysSpectralDensityLF.freqs.npy"), freqs) - - ### Save spike info ### - - spike_locations = analyzer.get_extension("spike_locations").get_data() - spike_depths = spike_locations["y"] - - # convert clusters and squeeze - clusters = np.load(output_folder / "spike_clusters.npy") - np.save(output_folder / "spike_clusters.npy", np.squeeze(clusters.astype("uint32"))) - - # convert times and squeeze - times = np.load(output_folder / "spike_times.npy") - np.save(output_folder / "spike_times.npy", np.squeeze(times / analyzer.sampling_frequency).astype("float64")) - - # convert amplitudes and squeeze - amps = np.load(output_folder / "amplitudes.npy") - np.save(output_folder / "amplitudes.npy", np.squeeze(-amps / 1e6).astype("float64")) - - # save depths and channel inds - np.save(output_folder / "spike_depths.npy", spike_depths) - np.save(output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int")) - - # # save templates - cluster_channels = [] - cluster_peak_to_trough = [] - cluster_waveforms = [] - templates = analyzer.get_extension("templates").get_data() - extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index") - - for unit_index, unit_id in enumerate(analyzer.unit_ids): - waveform = templates[unit_index, :, :] - extremum_channel_index = 
extremum_channel_indices[unit_id] - peak_waveform = waveform[:, extremum_channel_index] - peak_to_trough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency - # cluster_channels.append(int(channel_locs[extremum_channel_index, 1] / 10)) # ??? fails for odd nums of units - cluster_channels.append( - extremum_channel_index - ) # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870 - cluster_peak_to_trough.append(peak_to_trough) - cluster_waveforms.append(waveform) - - np.save(output_folder / "cluster_peak_to_trough.npy", np.array(cluster_peak_to_trough)) - np.save(output_folder / "cluster_waveforms.npy", np.stack(cluster_waveforms)) - np.save(output_folder / "cluster_channels.npy", np.array(cluster_channels)) - - # rename files from this func and the phy export func - _FILE_RENAMES = [ # file_in, file_out - ("channel_positions.npy", "channels.localCoordinates.npy"), - ("channel_inds.npy", "channels.rawInd.npy"), - ("cluster_peak_to_trough.npy", "clusters.peakToTrough.npy"), - ("cluster_channels.npy", "clusters.channels.npy"), - ("cluster_waveforms.npy", "clusters.waveforms.npy"), - ("spike_clusters.npy", "spikes.clusters.npy"), - ("amplitudes.npy", "spikes.amps.npy"), - ("spike_depths.npy", "spikes.depths.npy"), - ("spike_times.npy", "spikes.times.npy"), - ] + np.save(output_folder / "_iblqc_ephysSpectralDensityLF.power.npy", psd) + np.save(output_folder / "_iblqc_ephysSpectralDensityLF.freqs.npy", freqs) - for names in _FILE_RENAMES: - old_name = output_folder / names[0] - new_name = output_folder / names[1] - os.rename(old_name, new_name) - # save quality metrics - qm = analyzer.get_extension("quality_metrics") - qm_data = qm.get_data() - qm_data.index.name = "cluster_id" - qm_data["cluster_id.1"] = qm_data.index.values - amplitude_sign_coef = -1 if analyzer.get_extension("spike_amplitudes").params["peak_sign"] == "neg" else 1 +def compute_rms( + recording: BaseRecording, + verbose: bool = False, 
+ **job_kwargs, +): + """ + Compute the RMS of a recording in chunks. - good_ibl = ( # rough, slightly looser estimate of ibl standards - ((amplitude_sign_coef * qm_data["amplitude_median"]) > 40) - & (qm_data["isi_violations_ratio"] < 0.5) - & (qm_data["amplitude_cutoff"] < 0.2) + Parameters + ---------- + recording: BaseRecording + The recording object to compute the RMS for. + {} + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + # use executor (loop or workers) + func = _compute_rms_chunk + init_func = _init_rms_worker + init_args = (recording,) + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + job_name="compute_rms", + verbose=verbose, + handle_returns=True, + **job_kwargs, ) - qm_data["label"] = good_ibl.astype("int") - qm_data.to_csv(output_folder / "clusters.metrics.csv") + results = executor.run() + + rms_values = np.zeros((len(results), recording.get_num_channels())) + rms_times = np.zeros((len(results))) + + for i, result in enumerate(results): + rms_values[i, :], rms_times[i] = result + + return rms_values, rms_times + + +def _init_rms_worker(recording): + # create a local dict per worker + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["times"] = recording.get_times() + return worker_ctx + + +def _compute_rms_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + times = worker_ctx["times"] + + traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) + rms = np.sqrt(np.mean(traces**2, axis=0)) + # get the middle time of the chunk + if end_frame < recording.get_num_samples() - 1: + middle_frame = (start_frame + end_frame) // 2 + else: + # if we are at the end of the recording, use the middle point of the last chunk + middle_frame = (start_frame + recording.get_num_samples() - 1) // 2 + # get the time of the middle frame + rms_time = times[middle_frame] + + return rms, 
rms_time + + +compute_rms.__doc__ = compute_rms.__doc__.format(_shared_job_kwargs_doc) From a4a1c8094c0b70171330abb235ffdbafd78f991f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 23 Apr 2025 17:12:54 +0200 Subject: [PATCH 12/13] renamed export_to_ibl_gui and add docs and api ref --- doc/api.rst | 1 + doc/modules/exporters.rst | 38 +++++++++++++++++-- src/spikeinterface/exporters/__init__.py | 2 +- .../exporters/tests/test_export_to_ibl.py | 10 ++--- src/spikeinterface/exporters/to_ibl.py | 4 +- src/spikeinterface/exporters/to_phy.py | 2 +- 6 files changed, 45 insertions(+), 12 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7d20ed7c19..68bd6b70cc 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -335,6 +335,7 @@ spikeinterface.exporters .. automodule:: spikeinterface.exporters .. autofunction:: export_to_phy + .. autofunction:: export_to_ibl_gui .. autofunction:: export_report diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 7f8eeeb19e..3f2aa2dca5 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -25,7 +25,6 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` .. code-block:: python import spikeinterface as si # core module only - from spikeinterface.postprocessing import compute_spike_amplitudes, compute_principal_components from spikeinterface.exporters import export_to_phy # the waveforms are sparse so it is faster to export to phy @@ -40,6 +39,41 @@ The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:` export_to_phy(sorting_analyzer=sorting_analyzer, output_folder='path/to/phy_folder') +Export to IBL GUI +----------------- + +The :py:func:`~spikeinterface.exporters.export_to_ibl_gui` function allows you to use the +`IBL GUI <https://github.com/int-brain-lab/iblapps/wiki>`_ for probe alignment. + +The IBL GUI can also be installed as a standalone app using `this fork `_ from the Allen Institute.
+ +The input of the :py:func:`~spikeinterface.exporters.export_to_ibl_gui` is a :code:`SortingAnalyzer` object. + +.. code-block:: python + + import spikeinterface as si # core module only + import spikeinterface.preprocessing as spre + from spikeinterface.exporters import export_to_ibl_gui + + sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording) + + # we need to compute some required extensions + sorting_analyzer.compute(['random_spikes', 'templates', 'spike_amplitudes', 'spike_locations', 'quality_metrics']) + # note that spike_locations are optional, but recommended to compute accurate spike depths + + # optionally, we can pass an LFP recording to compute RMS/PSD in the LFP band + recording_lfp = spre.bandpass_filter(recording, freq_min=1, freq_max=300) + # we can also decimate the LFP to speed up the process + recording_lfp = spre.decimate(recording_lfp, 10) + + # the export process is fast because everything is pre-computed + export_to_ibl_gui( + sorting_analyzer=sorting_analyzer, + output_folder='path/to/ibl_folder', + lfp_recording=recording_lfp, + n_jobs=-1 + ) + Export a spike sorting report ----------------------------- @@ -68,8 +102,6 @@ with many units! ..
code-block:: python import spikeinterface as si # core module only - from spikeinterface.postprocessing import compute_spike_amplitudes, compute_correlograms - from spikeinterface.qualitymetrics import compute_quality_metrics from spikeinterface.exporters import export_report diff --git a/src/spikeinterface/exporters/__init__.py b/src/spikeinterface/exporters/__init__.py index 1bec3e5cab..dd0d7b0755 100644 --- a/src/spikeinterface/exporters/__init__.py +++ b/src/spikeinterface/exporters/__init__.py @@ -1,3 +1,3 @@ from .to_phy import export_to_phy from .report import export_report -from .to_ibl import export_to_ibl +from .to_ibl import export_to_ibl_gui diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py b/src/spikeinterface/exporters/tests/test_export_to_ibl.py index d44b33250b..0fd50c0d26 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_ibl.py +++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py @@ -1,7 +1,7 @@ import pytest from spikeinterface.preprocessing import bandpass_filter, decimate -from spikeinterface.exporters import export_to_ibl +from spikeinterface.exporters import export_to_ibl_gui from spikeinterface.exporters.tests.common import ( make_sorting_analyzer, @@ -37,7 +37,7 @@ def test_export_ap_to_ibl(sorting_analyzer_sparse_for_export, create_cache_folde sorting_analyzer = sorting_analyzer_sparse_for_export # AP, but no LFP - export_to_ibl( + export_to_ibl_gui( sorting_analyzer, output_folder, # good_units_query=good_units_query, @@ -61,7 +61,7 @@ def test_export_recordingless_to_ibl(sorting_analyzer_sparse_for_export, create_ sorting_analyzer._recording = None # AP, but no LFP - export_to_ibl(sorting_analyzer_sparse_for_export, output_folder, good_units_query=good_units_query, n_jobs=-1) + export_to_ibl_gui(sorting_analyzer_sparse_for_export, output_folder, good_units_query=good_units_query, n_jobs=-1) for f in required_output_files: assert (output_folder / f).exists(), f"Missing file: {f}" for f in 
ap_output_files: @@ -102,12 +102,12 @@ def test_missing_info(sorting_analyzer_sparse_for_export, create_cache_folder): good_units_query = "rp_violations < 0.2" with pytest.raises(ValueError, match="Missing required quality metrics"): - export_to_ibl(sorting_analyzer, output_folder, good_units_query=good_units_query, n_jobs=-1) + export_to_ibl_gui(sorting_analyzer, output_folder, good_units_query=good_units_query, n_jobs=-1) sorting_analyzer.delete_extension("spike_amplitudes") with pytest.raises(ValueError, match="Missing required extension"): - export_to_ibl(sorting_analyzer, output_folder, n_jobs=-1) + export_to_ibl_gui(sorting_analyzer, output_folder, n_jobs=-1) if __name__ == "__main__": diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 456d0974d1..53efd08b7e 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -14,7 +14,7 @@ from spikeinterface.exporters import export_to_phy -def export_to_ibl( +def export_to_ibl_gui( sorting_analyzer: SortingAnalyzer, output_folder: str | Path, lfp_recording: BaseRecording | None = None, @@ -28,7 +28,7 @@ def export_to_ibl( **job_kwargs, ): """ - Exports a sorting analyzer to the IBL GUI format (similar to the Phy format with some extras). + Exports a sorting analyzer to the format required by the `IBL alignment GUI `_. Parameters ---------- diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 06041da231..d3a823ce3f 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -35,7 +35,7 @@ def export_to_phy( **job_kwargs, ): """ - Exports a waveform extractor to the phy template-gui format. + Exports a sorting analyzer to the phy template-gui format. 
Parameters ---------- From 9fad937639f581b2e602566e5f244ce743ee9e62 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 23 Apr 2025 18:53:54 +0200 Subject: [PATCH 13/13] Fix typo in tests --- src/spikeinterface/exporters/tests/test_export_to_ibl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py b/src/spikeinterface/exporters/tests/test_export_to_ibl.py index 0fd50c0d26..3b859634df 100644 --- a/src/spikeinterface/exporters/tests/test_export_to_ibl.py +++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py @@ -81,7 +81,7 @@ def test_export_lfp_to_ibl(sorting_analyzer_sparse_for_export, create_cache_fold recording_lfp = bandpass_filter(recording, freq_min=0.5, freq_max=300) recording_lfp = decimate(recording_lfp, 10) # LFP, but no AP - export_to_ibl( + export_to_ibl_gui( sorting_analyzer, output_folder, lfp_recording=recording_lfp, good_units_query=good_units_query, n_jobs=-1 ) for f in required_output_files: