diff --git a/examples/tutorials/core/plot_4_sorting_analyzer.py b/examples/tutorials/core/plot_4_sorting_analyzer.py index 96b5f57b0d..3b49e35fff 100644 --- a/examples/tutorials/core/plot_4_sorting_analyzer.py +++ b/examples/tutorials/core/plot_4_sorting_analyzer.py @@ -43,9 +43,8 @@ ############################################################################## # Let's now instantiate the recording and sorting objects: -recording = se.MEArecRecordingExtractor(local_path) +recording, sorting = se.read_mearec(local_path) print(recording) -sorting = se.MEArecSortingExtractor(local_path) print(sorting) ############################################################################### diff --git a/examples/tutorials/extractors/plot_1_read_various_formats.py b/examples/tutorials/extractors/plot_1_read_various_formats.py index ef31b1dc76..b1fe277d92 100644 --- a/examples/tutorials/extractors/plot_1_read_various_formats.py +++ b/examples/tutorials/extractors/plot_1_read_various_formats.py @@ -61,7 +61,7 @@ # :py:class:`~spikeinterface.extractors.Spike2RecordingExtractor` object: # -recording = se.Spike2RecordingExtractor(spike2_file_path, stream_id="0") +recording = se.read_spike2(spike2_file_path, stream_id="0") print(recording) ############################################################################## @@ -75,11 +75,6 @@ print(sorting) print(type(sorting)) -############################################################################## -# The :py:func:`~spikeinterface.extractors.read_mearec` function is equivalent to: - -recording = se.MEArecRecordingExtractor(mearec_folder_path) -sorting = se.MEArecSortingExtractor(mearec_folder_path) ############################################################################## # SI objects (:py:class:`~spikeinterface.core.BaseRecording` and :py:class:`~spikeinterface.core.BaseSorting`) diff --git a/src/spikeinterface/extractors/__init__.py b/src/spikeinterface/extractors/__init__.py index 5e550017cc..7ac516e73b 100644 --- a/src/spikeinterface/extractors/__init__.py +++ b/src/spikeinterface/extractors/__init__.py @@ -1,9 +1,64 @@ -from .extractorlist import * +from .extractor_classes import * -from .toy_example import toy_example -from .bids import read_bids +from .toy_example import toy_example as toy_example +from .bids import read_bids as read_bids from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts from .neoextractors import get_neo_num_blocks, get_neo_streams + +from warnings import warn + + +# deprecation of class import idea from neuroconv +# this __getattr__ is only triggered if the normal lookup fails so import +# any of our functions is fine but if someone tries to import a class this raises +# the warning and then returns the "function" version which will look the same +# to the end-user +# to be removed after version 0.105.0 +def __getattr__(extractor_name): + # we need this trick to allow us to use import * for spikeinterface.full + if extractor_name == "__all__": + __all__ = [] + for imp in globals(): + # need to remove a bunch of builtins etc that shouldn't be part of all + if imp[0] != "_" and imp != "warn" and imp != "extractor_name": + __all__.append(imp) + return __all__ + all_extractors = list(recording_extractor_full_dict.values()) + all_extractors += list(sorting_extractor_full_dict.values()) + all_extractors += list(event_extractor_full_dict.values()) + all_extractors += list(snippets_extractor_full_dict.values()) + # special cases because they don't have simple wrappers + # instead a single wrapper maps to multiple classes so we return + # each class to check it + from .neoextractors import ( + MEArecRecordingExtractor, + MEArecSortingExtractor, + OpenEphysBinaryEventExtractor, + OpenEphysBinaryRecordingExtractor, + OpenEphysLegacyRecordingExtractor, + SpikeGLXEventExtractor, + ) + + all_extractors += [ + MEArecRecordingExtractor, + MEArecSortingExtractor, + OpenEphysBinaryEventExtractor, + OpenEphysBinaryRecordingExtractor, + OpenEphysLegacyRecordingExtractor, + SpikeGLXEventExtractor, + ] + for reading_function in all_extractors: + if extractor_name == reading_function.__name__: + dep_msg = ( + "Importing classes at __init__ has been deprecated in favor of only importing function-size wrappers " + "and will be removed in 0.105.0. For developers that prefer working with the class versions of extractors " + "they can be imported from spikeinterface.extractors.extractor_classes" + ) + warn(dep_msg) + return reading_function + # this is necessary for objects that we don't support + # normally this is an ImportError but since this is in the __getattr__ pytest needs an AttributeError + raise AttributeError(f"cannot import name '{extractor_name}' from '{__name__}'") diff --git a/src/spikeinterface/extractors/extractor_classes.py b/src/spikeinterface/extractors/extractor_classes.py new file mode 100644 index 0000000000..8d4aa8ed32 --- /dev/null +++ b/src/spikeinterface/extractors/extractor_classes.py @@ -0,0 +1,202 @@ +from __future__ import annotations + + +# most important extractor are in spikeinterface.core +from spikeinterface.core import ( + BinaryFolderRecording, + BinaryRecordingExtractor, + NumpyRecording, + NpzSortingExtractor, + NumpySorting, + NpySnippetsExtractor, + ZarrRecordingExtractor, + ZarrSortingExtractor, + read_binary, + read_zarr, + read_npz_sorting, + read_npy_snippets, +) + +# sorting/recording/event from neo +from .neoextractors import * + +# non-NEO objects implemented in neo folder +# keep for reference Currently pulling from neoextractor __init__ +# from .neoextractors import NeuroScopeSortingExtractor, MaxwellEventExtractor + +# NWB sorting/recording/event +from .nwbextractors import ( + NwbRecordingExtractor, + NwbSortingExtractor, + NwbTimeSeriesExtractor, + read_nwb, + read_nwb_recording, + read_nwb_sorting, + read_nwb_timeseries, +) + +from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl +from .iblextractors import IblRecordingExtractor, IblSortingExtractor, read_ibl_recording, read_ibl_sorting +from .mcsh5extractors import MCSH5RecordingExtractor, read_mcsh5 +from .whitematterrecordingextractor import WhiteMatterRecordingExtractor, read_whitematter + +# sorting extractors in relation with a sorter +from .cellexplorersortingextractor import CellExplorerSortingExtractor, read_cellexplorer +from .klustaextractors import KlustaSortingExtractor, read_klusta +from .hdsortextractors import HDSortSortingExtractor, read_hdsort +from .mclustextractors import MClustSortingExtractor, read_mclust +from .waveclustextractors import WaveClusSortingExtractor, read_waveclus +from .yassextractors import YassSortingExtractor, read_yass +from .combinatoextractors import CombinatoSortingExtractor, read_combinato +from .tridesclousextractors import TridesclousSortingExtractor, read_tridesclous +from .spykingcircusextractors import SpykingCircusSortingExtractor, read_spykingcircus +from .herdingspikesextractors import HerdingspikesSortingExtractor, read_herdingspikes +from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting +from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort +from .sinapsrecordingextractors import ( + SinapsResearchPlatformRecordingExtractor, + SinapsResearchPlatformH5RecordingExtractor, + read_sinaps_research_platform, + read_sinaps_research_platform_h5, +) + +# sorting in relation with simulator +from .shybridextractors import ( + SHYBRIDRecordingExtractor, + SHYBRIDSortingExtractor, + read_shybrid_recording, + read_shybrid_sorting, +) + +# snippers +from .waveclussnippetstextractors import WaveClusSnippetsExtractor, read_waveclus_snippets + + +# misc +from .alfsortingextractor import ALFSortingExtractor, read_alf_sorting + + +############################################################################################### +# the following code is necessary for controlling what the end user imports from spikeinterface. +# The strategy has three goals: +# +# * A mapping from the original class to its wrapper (because that's what we want to expose) +# * A mapping from the original class to its wrapper string (because of __all__) +# * A mapping from format to the class wrapper for convenience (exposed to users for ease of use) +# +# To achieve these there goals we do the following: +# +# 1) we line up each class with its wrapper that returns a snakecase version of the class (in some docs called +# the "function" version, although this is just a wrapper of the underlying class) +# 2) we do (1) by creating nested dicts where the key is the original class and the values are a nested dict with +# 3) a "wrapper_class" key which returns the wrapper to be exposed to the end user and +# 4) a "wrapper_string" which is added to the __all__ attribute of the __init__. This is necessary because __all__ +# can only accept a list of strings +# 5) Finally we create dictionaries exposed to the user where we return a formatted file format as a key along +# with the value being the wrapper (see the comment below for examples for this dict) +# +# Note that some formats (e.g. binary and numpy) still use the class format as they aren't read-only (i.e. they +# have no wrapper) + +_recording_extractor_full_dict = { + # core extractors that are returned as classes + BinaryFolderRecording: dict(wrapper_string="BinaryFolderRecording", wrapper_class=BinaryFolderRecording), + BinaryRecordingExtractor: dict(wrapper_string="BinaryRecordingExtractor", wrapper_class=BinaryRecordingExtractor), + ZarrRecordingExtractor: dict(wrapper_string="ZarrRecordingExtractor", wrapper_class=ZarrRecordingExtractor), + # natively implemented in spikeinterface.extractors + NumpyRecording: dict(wrapper_string="NumpyRecording", wrapper_class=NumpyRecording), + SHYBRIDRecordingExtractor: dict(wrapper_string="read_shybrid_recording", wrapper_class=read_shybrid_recording), + MdaRecordingExtractor: dict(wrapper_string="read_mda_recording", wrapper_class=read_mda_recording), + NwbRecordingExtractor: dict(wrapper_string="read_nwb_recording", wrapper_class=read_nwb_recording), + NwbTimeSeriesExtractor: dict(wrapper_string="read_nwb_timeseries", wrapper_class=read_nwb_timeseries), + # others + CompressedBinaryIblExtractor: dict(wrapper_string="read_cbin_ibl", wrapper_class=read_cbin_ibl), + IblRecordingExtractor: dict(wrapper_string="read_ibl_recording", wrapper_class=read_ibl_recording), + MCSH5RecordingExtractor: dict(wrapper_string="read_mcsh5", wrapper_class=read_mcsh5), + SinapsResearchPlatformRecordingExtractor: dict( + wrapper_string="read_sinaps_research_platform", wrapper_class=read_sinaps_research_platform + ), + SinapsResearchPlatformH5RecordingExtractor: dict( + wrapper_string="read_sinaps_research_platform_h5", wrapper_class=read_sinaps_research_platform_h5 + ), + WhiteMatterRecordingExtractor: dict(wrapper_string="read_whitematter", wrapper_class=read_whitematter), +} +_recording_extractor_full_dict.update(neo_recording_extractors_dict) + +_sorting_extractor_full_dict = { + NpzSortingExtractor: dict(wrapper_string="read_npz_sorting", wrapper_class=read_npz_sorting), + ZarrSortingExtractor: dict(wrapper_string="ZarrSortingExtractor", wrapper_class=ZarrSortingExtractor), + NumpySorting: dict(wrapper_string="NumpySorting", wrapper_class=NumpySorting), + # natively implemented in spikeinterface.extractors + MdaSortingExtractor: dict(wrapper_string="read_mda_sorting", wrapper_class=read_mda_sorting), + SHYBRIDSortingExtractor: dict(wrapper_string="read_shybrid_sorting", wrapper_class=read_shybrid_sorting), + ALFSortingExtractor: dict(wrapper_string="read_alf_sorting", wrapper_class=read_alf_sorting), + KlustaSortingExtractor: dict(wrapper_string="read_klusta", wrapper_class=read_klusta), + HDSortSortingExtractor: dict(wrapper_string="read_hdsort", wrapper_class=read_hdsort), + MClustSortingExtractor: dict(wrapper_string="read_mclust", wrapper_class=read_mclust), + WaveClusSortingExtractor: dict(wrapper_string="read_waveclus", wrapper_class=read_waveclus), + YassSortingExtractor: dict(wrapper_string="read_yass", wrapper_class=read_yass), + CombinatoSortingExtractor: dict(wrapper_string="read_combinato", wrapper_class=read_combinato), + TridesclousSortingExtractor: dict(wrapper_string="read_tridesclous", wrapper_class=read_tridesclous), + SpykingCircusSortingExtractor: dict(wrapper_string="read_spykingcircus", wrapper_class=read_spykingcircus), + HerdingspikesSortingExtractor: dict(wrapper_string="read_herdingspikes", wrapper_class=read_herdingspikes), + KiloSortSortingExtractor: dict(wrapper_string="read_kilosort", wrapper_class=read_kilosort), + PhySortingExtractor: dict(wrapper_string="read_phy", wrapper_class=read_phy), + NwbSortingExtractor: dict(wrapper_string="read_nwb_sorting", wrapper_class=read_nwb_sorting), + IblSortingExtractor: dict(wrapper_string="read_ibl_sorting", wrapper_class=read_ibl_sorting), + CellExplorerSortingExtractor: dict(wrapper_string="read_cellexplorer", wrapper_class=read_cellexplorer), +} +_sorting_extractor_full_dict.update(neo_sorting_extractors_dict) + +# events only from neo +_event_extractor_full_dict = neo_event_extractors_dict + +_snippets_extractor_full_dict = { + NpySnippetsExtractor: dict(wrapper_string="read_npy_snippets", wrapper_class=read_npy_snippets), + WaveClusSnippetsExtractor: dict(wrapper_string="read_waveclus_snippets", wrapper_class=read_waveclus_snippets), +} + +############################################################################################################ +# Organize the possible extractors into a user facing format with keys being extractor names +# (e.g. 'intan' , 'kilosort') and values being the appropriate Extractor class returned as its wrapper +# (e.g. IntanRecordingExtractor, KiloSortSortingExtractor) +# An important note is the the formats are returned after performing `.lower()` so a format like +# SpikeGLX will be a key of 'spikeglx' +# for example if we wanted to create a recording from an intan file we could do the following: +# >>> recording = se.recording_extractor_full_dict['intan'](file_path='path/to/data.rhd') + + +recording_extractor_full_dict = { + rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower(): rec_func["wrapper_class"] + for rec_class, rec_func in _recording_extractor_full_dict.items() +} +sorting_extractor_full_dict = { + sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower(): sort_func["wrapper_class"] + for sort_class, sort_func in _sorting_extractor_full_dict.items() +} +event_extractor_full_dict = { + event_class.__name__.replace("Event", "").replace("Extractor", "").lower(): event_func["wrapper_class"] + for event_class, event_func in _event_extractor_full_dict.items() +} +snippets_extractor_full_dict = { + snippets_class.__name__.replace("Snippets", "").replace("Extractor", "").lower(): snippets_func["wrapper_class"] + for snippets_class, snippets_func in _snippets_extractor_full_dict.items() +} + + +# we only do the functions in the init rather than pull in the classes +__all__ = [func["wrapper_string"] for func in _recording_extractor_full_dict.values()] +__all__ += [func["wrapper_string"] for func in _sorting_extractor_full_dict.values()] +__all__ += [func["wrapper_string"] for func in _event_extractor_full_dict.values()] +__all__ += [func["wrapper_string"] for func in _snippets_extractor_full_dict.values()] +__all__.extend( + [ + "read_nwb", # convenience function for multiple nwb formats + "recording_extractor_full_dict", + "sorting_extractor_full_dict", + "event_extractor_full_dict", + "snippets_extractor_full_dict", + "read_binary", # convenience function for binary formats + "read_zarr", + ] +) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py deleted file mode 100644 index 6e872cde58..0000000000 --- a/src/spikeinterface/extractors/extractorlist.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import annotations - -from typing import Type - -# most important extractor are in spikeinterface.core -from spikeinterface.core import ( - BaseRecording, - BaseSorting, - BinaryFolderRecording, - BinaryRecordingExtractor, - NumpyRecording, - NpzSortingExtractor, - NumpySorting, - NpySnippetsExtractor, - ZarrRecordingExtractor, - ZarrSortingExtractor, - read_binary, - read_zarr, - read_npz_sorting, -) - -# sorting/recording/event from neo -from .neoextractors import * - -# non-NEO objects implemented in neo folder -from .neoextractors import NeuroScopeSortingExtractor, MaxwellEventExtractor - -# NWB sorting/recording/event -from .nwbextractors import ( - NwbRecordingExtractor, - NwbSortingExtractor, - NwbTimeSeriesExtractor, - read_nwb, - read_nwb_recording, - read_nwb_sorting, - read_nwb_timeseries, -) - -from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl -from .iblextractors import IblRecordingExtractor, IblSortingExtractor, read_ibl_recording, read_ibl_sorting -from .mcsh5extractors import MCSH5RecordingExtractor, read_mcsh5 -from .whitematterrecordingextractor import WhiteMatterRecordingExtractor - -# sorting extractors in relation with a sorter -from .cellexplorersortingextractor import CellExplorerSortingExtractor, read_cellexplorer -from .klustaextractors import KlustaSortingExtractor, read_klusta -from .hdsortextractors import HDSortSortingExtractor, read_hdsort -from .mclustextractors import MClustSortingExtractor, read_mclust -from .waveclustextractors import WaveClusSortingExtractor, read_waveclus -from .yassextractors import YassSortingExtractor, read_yass -from .combinatoextractors import CombinatoSortingExtractor, read_combinato -from .tridesclousextractors import TridesclousSortingExtractor, read_tridesclous -from .spykingcircusextractors import SpykingCircusSortingExtractor, read_spykingcircus -from .herdingspikesextractors import HerdingspikesSortingExtractor, read_herdingspikes -from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor, read_mda_recording, read_mda_sorting -from .phykilosortextractors import PhySortingExtractor, KiloSortSortingExtractor, read_phy, read_kilosort -from .sinapsrecordingextractors import ( - SinapsResearchPlatformRecordingExtractor, - SinapsResearchPlatformH5RecordingExtractor, - read_sinaps_research_platform, - read_sinaps_research_platform_h5, -) - -# sorting in relation with simulator -from .shybridextractors import ( - SHYBRIDRecordingExtractor, - SHYBRIDSortingExtractor, - read_shybrid_recording, - read_shybrid_sorting, -) - -# snippers -from .waveclussnippetstextractors import WaveClusSnippetsExtractor, read_waveclus_snippets - - -# misc -from .alfsortingextractor import ALFSortingExtractor, read_alf_sorting - - -######################################## - -recording_extractor_full_list = [ - BinaryFolderRecording, - BinaryRecordingExtractor, - ZarrRecordingExtractor, - # natively implemented in spikeinterface.extractors - NumpyRecording, - SHYBRIDRecordingExtractor, - MdaRecordingExtractor, - NwbRecordingExtractor, - # others - CompressedBinaryIblExtractor, - IblRecordingExtractor, - MCSH5RecordingExtractor, - SinapsResearchPlatformRecordingExtractor, - WhiteMatterRecordingExtractor, -] -recording_extractor_full_list += neo_recording_extractors_list - -sorting_extractor_full_list = [ - NpzSortingExtractor, - ZarrSortingExtractor, - NumpySorting, - # natively implemented in spikeinterface.extractors - MdaSortingExtractor, - SHYBRIDSortingExtractor, - ALFSortingExtractor, - KlustaSortingExtractor, - HDSortSortingExtractor, - MClustSortingExtractor, - WaveClusSortingExtractor, - YassSortingExtractor, - CombinatoSortingExtractor, - TridesclousSortingExtractor, - SpykingCircusSortingExtractor, - HerdingspikesSortingExtractor, - KiloSortSortingExtractor, - PhySortingExtractor, - NwbSortingExtractor, - NeuroScopeSortingExtractor, - IblSortingExtractor, -] -sorting_extractor_full_list += neo_sorting_extractors_list - -event_extractor_full_list = [MaxwellEventExtractor] -event_extractor_full_list += neo_event_extractors_list - -snippets_extractor_full_list = [NpySnippetsExtractor, WaveClusSnippetsExtractor] - -recording_extractor_full_dict = {} -for rec_class in recording_extractor_full_list: - # here we get the class name, remove "Recording" and "Extractor" and make it lower case - rec_class_name = rec_class.__name__.replace("Recording", "").replace("Extractor", "").lower() - recording_extractor_full_dict[rec_class_name] = rec_class - -sorting_extractor_full_dict = {} -for sort_class in sorting_extractor_full_list: - # here we get the class name, remove "Extractor" and make it lower case - sort_class_name = sort_class.__name__.replace("Sorting", "").replace("Extractor", "").lower() - sorting_extractor_full_dict[sort_class_name] = sort_class - -event_extractor_full_dict = {} -for event_class in event_extractor_full_list: - # here we get the class name, remove "Extractor" and make it lower case - event_class_name = event_class.__name__.replace("Event", "").replace("Extractor", "").lower() - event_extractor_full_dict[event_class_name] = event_class diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 9c5020b098..2f672aafde 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -359,5 +359,5 @@ def __init__( self._kwargs = dict(pid=pid, good_clusters_only=good_clusters_only, load_unit_properties=load_unit_properties) -read_ibl_recording = define_function_from_class(source_class=IblRecordingExtractor, name="read_ibl_streaming_recording") +read_ibl_recording = define_function_from_class(source_class=IblRecordingExtractor, name="read_ibl_recording") read_ibl_sorting = define_function_from_class(source_class=IblSortingExtractor, name="read_ibl_sorting") diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 03d517b46e..269731d5c3 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -42,42 +42,47 @@ from .neo_utils import get_neo_streams, get_neo_num_blocks -neo_recording_extractors_list = [ - AlphaOmegaRecordingExtractor, - AxonaRecordingExtractor, - BiocamRecordingExtractor, - BlackrockRecordingExtractor, - CedRecordingExtractor, - EDFRecordingExtractor, - IntanRecordingExtractor, - MaxwellRecordingExtractor, - MEArecRecordingExtractor, - MCSRawRecordingExtractor, - NeuralynxRecordingExtractor, - NeuroScopeRecordingExtractor, - NeuroNexusRecordingExtractor, - NixRecordingExtractor, - OpenEphysBinaryRecordingExtractor, - OpenEphysLegacyRecordingExtractor, - PlexonRecordingExtractor, - Plexon2RecordingExtractor, - Spike2RecordingExtractor, - SpikeGadgetsRecordingExtractor, - SpikeGLXRecordingExtractor, - TdtRecordingExtractor, - NeuroExplorerRecordingExtractor, -] +neo_recording_extractors_dict = { + AlphaOmegaRecordingExtractor: dict(wrapper_string="read_alphaomega", wrapper_class=read_alphaomega), + AxonaRecordingExtractor: dict(wrapper_string="read_axona", wrapper_class=read_axona), + BiocamRecordingExtractor: dict(wrapper_string="read_biocam", wrapper_class=read_biocam), + BlackrockRecordingExtractor: dict(wrapper_string="read_blackrock", wrapper_class=read_blackrock), + CedRecordingExtractor: dict(wrapper_string="read_ced", wrapper_class=read_ced), + EDFRecordingExtractor: dict(wrapper_string="read_edf", wrapper_class=read_edf), + IntanRecordingExtractor: dict(wrapper_string="read_intan", wrapper_class=read_intan), + MaxwellRecordingExtractor: dict(wrapper_string="read_maxwell", wrapper_class=read_maxwell), + MEArecRecordingExtractor: dict(wrapper_string="read_mearec", wrapper_class=read_mearec), + MCSRawRecordingExtractor: dict(wrapper_string="read_mcsraw", wrapper_class=read_mcsraw), + NeuralynxRecordingExtractor: dict(wrapper_string="read_neuralynx", wrapper_class=read_neuralynx), + NeuroScopeRecordingExtractor: dict( + wrapper_string="read_neuroscope_recording", wrapper_class=read_neuroscope_recording + ), + NeuroNexusRecordingExtractor: dict(wrapper_string="read_neuronexus", wrapper_class=read_neuronexus), + NixRecordingExtractor: dict(wrapper_string="read_nix", wrapper_class=read_nix), + OpenEphysBinaryRecordingExtractor: dict(wrapper_string="read_openephys", wrapper_class=read_openephys), + OpenEphysLegacyRecordingExtractor: dict(wrapper_string="read_openephys", wrapper_class=read_openephys), + PlexonRecordingExtractor: dict(wrapper_string="read_plexon", wrapper_class=read_plexon), + Plexon2RecordingExtractor: dict(wrapper_string="read_plexon2", wrapper_class=read_plexon2), + Spike2RecordingExtractor: dict(wrapper_string="read_spike2", wrapper_class=read_spike2), + SpikeGadgetsRecordingExtractor: dict(wrapper_string="read_spikegadgets", wrapper_class=read_spikegadgets), + SpikeGLXRecordingExtractor: dict(wrapper_string="read_spikeglx", wrapper_class=read_spikeglx), + TdtRecordingExtractor: dict(wrapper_string="read_tdt", wrapper_class=read_tdt), + NeuroExplorerRecordingExtractor: dict(wrapper_string="read_neuroexplorer", wrapper_class=read_neuroexplorer), +} -neo_sorting_extractors_list = [ - BlackrockSortingExtractor, - MEArecSortingExtractor, - NeuralynxSortingExtractor, - Plexon2SortingExtractor, -] +neo_sorting_extractors_dict = { + BlackrockSortingExtractor: dict(wrapper_string="read_blackrock_sorting", wrapper_class=read_blackrock_sorting), + MEArecSortingExtractor: dict(wrapper_string="read_mearec", wrapper_class=read_mearec), + NeuralynxSortingExtractor: dict(wrapper_string="read_neuralynx_sorting", wrapper_class=read_neuralynx_sorting), + PlexonSortingExtractor: dict(wrapper_string="read_plexon_sorting", wrapper_class=read_plexon_sorting), + Plexon2SortingExtractor: dict(wrapper_string="read_plexon2_sorting", wrapper_class=read_plexon2_sorting), + NeuroScopeSortingExtractor: dict(wrapper_string="read_neuroscope_sorting", wrapper_class=read_neuroscope_sorting), +} -neo_event_extractors_list = [ - AlphaOmegaEventExtractor, - OpenEphysBinaryEventExtractor, - Plexon2EventExtractor, - SpikeGLXEventExtractor, -] +neo_event_extractors_dict = { + AlphaOmegaEventExtractor: dict(wrapper_string="read_alphaomega_event", wrapper_class=read_alphaomega_event), + OpenEphysBinaryEventExtractor: dict(wrapper_string="read_openephys_event", wrapper_class=read_openephys_event), + Plexon2EventExtractor: dict(wrapper_string="read_plexon2_event", wrapper_class=read_plexon2_event), + SpikeGLXEventExtractor: dict(wrapper_string="read_spikeglx_event", wrapper_class=read_spikeglx_event), + MaxwellEventExtractor: dict(wrapper_string="read_maxwell_event", wrapper_class=read_maxwell_event), +} diff --git a/src/spikeinterface/extractors/neoextractors/neo_utils.py b/src/spikeinterface/extractors/neoextractors/neo_utils.py index 3de83ff607..ac705e8b4e 100644 --- a/src/spikeinterface/extractors/neoextractors/neo_utils.py +++ b/src/spikeinterface/extractors/neoextractors/neo_utils.py @@ -56,7 +56,7 @@ def get_neo_num_blocks(extractor_name, *args, **kwargs) -> int: def get_neo_extractor(extractor_name): - from spikeinterface.extractors.extractorlist import recording_extractor_full_dict + from spikeinterface.extractors.extractor_classes import recording_extractor_full_dict assert extractor_name in recording_extractor_full_dict, ( f"{extractor_name} not an extractor name:" f"\n{list(recording_extractor_full_dict.keys())}" diff --git a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py index 2e364b13bc..905994197e 100644 --- a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py +++ b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py @@ -1,8 +1,7 @@ import pytest -import numpy as np import unittest -from spikeinterface.extractors import CompressedBinaryIblExtractor, read_cbin_ibl +from spikeinterface.extractors.extractor_classes import CompressedBinaryIblExtractor, read_cbin_ibl from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 334ca63d6a..adec4213e3 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -6,7 +6,7 @@ from numpy.testing import assert_array_equal import pytest -from spikeinterface.extractors import read_ibl_recording, read_ibl_sorting, IblRecordingExtractor +from spikeinterface.extractors.extractor_classes import read_ibl_recording, read_ibl_sorting, IblRecordingExtractor EID = "e2b845a1-e313-4a08-bc61-a5f662ed295e" PID = "80f6ffdd-f692-450f-ab19-cd6d45bfd73e" diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index ef8c73ac64..37dd9720fd 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -2,7 +2,7 @@ from pathlib import Path from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.extractors import MdaRecordingExtractor, MdaSortingExtractor +from spikeinterface.extractors.extractor_classes import MdaRecordingExtractor, MdaSortingExtractor def test_mda_extractors(create_cache_folder): diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 395f1e0d9e..7ef3a362e8 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -8,7 +8,41 @@ import pytest from spikeinterface import get_global_dataset_folder -from spikeinterface.extractors import * +from spikeinterface.extractors.extractor_classes import ( + MEArecRecordingExtractor, + MEArecSortingExtractor, + SpikeGLXRecordingExtractor, + OpenEphysBinaryRecordingExtractor, + OpenEphysBinaryEventExtractor, + OpenEphysLegacyRecordingExtractor, + IntanRecordingExtractor, + NeuroScopeRecordingExtractor, + NeuroExplorerRecordingExtractor, + NeuroScopeSortingExtractor, + NeuroNexusRecordingExtractor, + PlexonRecordingExtractor, + PlexonSortingExtractor, + NeuralynxRecordingExtractor, + AlphaOmegaEventExtractor, + SpikeGadgetsRecordingExtractor, + Plexon2SortingExtractor, + NeuralynxSortingExtractor, + BlackrockRecordingExtractor, + BlackrockSortingExtractor, + MCSRawRecordingExtractor, + TdtRecordingExtractor, + BiocamRecordingExtractor, + AxonaRecordingExtractor, + Plexon2EventExtractor, + MaxwellRecordingExtractor, + CedRecordingExtractor, + AlphaOmegaRecordingExtractor, + Spike2RecordingExtractor, + EDFRecordingExtractor, + Plexon2RecordingExtractor, +) + +from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor from spikeinterface.extractors.tests.common_tests import ( RecordingCommonTestSuite, diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 4bfc43dd69..15d3e8fee9 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -5,7 +5,11 @@ import pytest import numpy as np -from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor, NwbTimeSeriesExtractor +from spikeinterface.extractors.extractor_classes import ( + NwbRecordingExtractor, + NwbSortingExtractor, + NwbTimeSeriesExtractor, +) from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite from spikeinterface.core.testing import check_recordings_equal diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py index 9724ec3d9f..84ae3c03bf 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py @@ -7,7 +7,7 @@ from spikeinterface import load from spikeinterface.core.testing import check_recordings_equal from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor +from spikeinterface.extractors.extractor_classes import NwbRecordingExtractor, NwbSortingExtractor @pytest.mark.streaming_extractors diff --git a/src/spikeinterface/extractors/tests/test_shybridextractors.py b/src/spikeinterface/extractors/tests/test_shybridextractors.py index 221e1bfc2d..bc6f17d736 100644 --- a/src/spikeinterface/extractors/tests/test_shybridextractors.py +++ b/src/spikeinterface/extractors/tests/test_shybridextractors.py @@ -2,7 +2,7 @@ from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core.testing import check_recordings_equal, check_sortings_equal -from spikeinterface.extractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor +from spikeinterface.extractors.extractor_classes import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor @pytest.mark.skipif(True, reason="SHYBRID only tested locally") diff --git a/src/spikeinterface/extractors/tests/test_whitematterrecordingextractor.py b/src/spikeinterface/extractors/tests/test_whitematterrecordingextractor.py index 546949d60f..ebd50b1822 100644 --- a/src/spikeinterface/extractors/tests/test_whitematterrecordingextractor.py +++ b/src/spikeinterface/extractors/tests/test_whitematterrecordingextractor.py @@ -3,7 +3,7 @@ from pathlib import Path import pickle -from spikeinterface.extractors import WhiteMatterRecordingExtractor, BinaryRecordingExtractor +from spikeinterface.extractors.extractor_classes import WhiteMatterRecordingExtractor, BinaryRecordingExtractor from spikeinterface.core.numpyextractors import NumpyRecording from spikeinterface.core.testing import check_recordings_equal from spikeinterface import get_global_dataset_folder, download_dataset diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index 4f63eb7900..bfa16d1670 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -1,4 +1,4 @@ -from .preprocessinglist import * +from .preprocessing_classes import * from .motion import ( correct_motion, @@ -16,3 +16,34 @@ # for snippets from .align_snippets import AlignSnippets +from warnings import warn + + +# deprecation of class import idea from neuroconv +# this __getattr__ is only triggered if the normal lookup fails so import +# any of our functions is fine but if someone tries to import a class this raises +# the warning and then returns the "function" version which will look the same +# to the end-user +# to be removed after version 0.105.0 +def __getattr__(preprocessor_name): + # we need this trick to allow us to use import * for spikeinterface.full + if preprocessor_name == "__all__": + __all__ = [] + for imp in globals(): + # need to remove a bunch of builtins etc that shouldn't be part of all + if imp[0] != "_" and imp != "warn" and imp != "preprocessor_name": + __all__.append(imp) + return __all__ + from .preprocessing_classes import _all_preprocesser_dict + + for pp_class, pp_function in _all_preprocesser_dict.items(): + if preprocessor_name == pp_class.__name__: + dep_msg = ( + "Importing classes at __init__ has been deprecated in favor of only importing functions " + "and will be removed in 0.105.0. For developers that prefer working with the class versions of preprocessors " + "they can be imported from spikeinterface.preprocessors.preprocessor_classes." + ) + warn(dep_msg) + return pp_function + # this is necessary for objects that we don't support + raise AttributeError(f"cannot import name '{preprocessor_name}' from '{__name__}'") diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessing_classes.py similarity index 51% rename from src/spikeinterface/preprocessing/preprocessinglist.py rename to src/spikeinterface/preprocessing/preprocessing_classes.py index 0745da4d64..705ed3428b 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -25,8 +25,8 @@ CenterRecording, center, ) -from .scale import scale_to_uV, scale_to_physical_units +from .scale import scale_to_uV, ScaleToPhysicalUnits, scale_to_physical_units from .whiten import WhitenRecording, whiten, compute_whitening_matrix from .rectify import RectifyRecording, rectify @@ -45,41 +45,49 @@ from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed - -preprocessers_full_list = [ +_all_preprocesser_dict = { # filter stuff - FilterRecording, - BandpassFilterRecording, - HighpassFilterRecording, - NotchFilterRecording, - GaussianFilterRecording, + FilterRecording: filter, + BandpassFilterRecording: bandpass_filter, + HighpassFilterRecording: highpass_filter, + NotchFilterRecording: notch_filter, + GaussianFilterRecording: gaussian_filter, # gain offset stuff - NormalizeByQuantileRecording, - ScaleRecording, - CenterRecording, - ZScoreRecording, + NormalizeByQuantileRecording: normalize_by_quantile, + ScaleRecording: scale, + CenterRecording: center, + ZScoreRecording: zscore, + ScaleToPhysicalUnits: scale_to_physical_units, # decorrelation stuff - WhitenRecording, + WhitenRecording: whiten, # re-reference - CommonReferenceRecording, - PhaseShiftRecording, + CommonReferenceRecording: common_reference, + PhaseShiftRecording: phase_shift, # misc - RectifyRecording, - ClipRecording, - BlankSaturationRecording, - SilencedPeriodsRecording, - RemoveArtifactsRecording, - ZeroChannelPaddedRecording, - DeepInterpolatedRecording, - ResampleRecording, - DecimateRecording, - HighpassSpatialFilterRecording, - InterpolateBadChannelsRecording, - DepthOrderRecording, - AverageAcrossDirectionRecording, - DirectionalDerivativeRecording, - AstypeRecording, - UnsignedToSignedRecording, -] + RectifyRecording: rectify, + ClipRecording: clip, + BlankSaturationRecording: blank_saturation, + SilencedPeriodsRecording: silence_periods, + RemoveArtifactsRecording: remove_artifacts, + ZeroChannelPaddedRecording: zero_channel_pad, + DeepInterpolatedRecording: deepinterpolate, + ResampleRecording: resample, + DecimateRecording: decimate, + HighpassSpatialFilterRecording: highpass_spatial_filter, + InterpolateBadChannelsRecording: interpolate_bad_channels, + DepthOrderRecording: depth_order, + AverageAcrossDirectionRecording: average_across_direction, + DirectionalDerivativeRecording: directional_derivative, + AstypeRecording: astype, + UnsignedToSignedRecording: unsigned_to_signed, +} +# we control import in the preprocessing init by setting an __all__ + +# pp_function.__name__ gives the name of the function that users should use +__all__ = [pp_function.__name__ for pp_function in _all_preprocesser_dict.values()] +__all__.extend( + [scale_to_uV.__name__, compute_whitening_matrix.__name__, train_deepinterpolation.__name__, causal_filter.__name__] +) -preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list} +preprocessor_dict = {pp_class.__name__: pp_function for pp_class, pp_function in _all_preprocesser_dict.items()} +__all__.append("preprocessor_dict") diff --git a/src/spikeinterface/preprocessing/scale.py b/src/spikeinterface/preprocessing/scale.py index 3cb8719e58..9550583b75 100644 --- a/src/spikeinterface/preprocessing/scale.py +++ b/src/spikeinterface/preprocessing/scale.py @@ -86,7 +86,7 @@ def scale_to_uV(recording: BasePreprocessor) -> BasePreprocessor: If the recording extractor does not have scaleable traces. """ # To avoid a circular import - from spikeinterface.preprocessing import ScaleRecording + from spikeinterface.preprocessing.preprocessing_classes import ScaleRecording if not recording.has_scaleable_traces(): error_msg = "Recording must have gains and offsets set to be scaled to µV" diff --git a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py index c0965d8e51..aa96841aaa 100644 --- a/src/spikeinterface/preprocessing/tests/test_average_across_direction.py +++ b/src/spikeinterface/preprocessing/tests/test_average_across_direction.py @@ -1,10 +1,6 @@ -import pytest -from pathlib import Path - -from spikeinterface import set_global_tmp_folder from spikeinterface.core import NumpyRecording -from spikeinterface.preprocessing import AverageAcrossDirectionRecording, average_across_direction +from spikeinterface.preprocessing import average_across_direction import numpy as np diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index aab17560a6..e9493145a6 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -1,7 +1,6 @@ import pytest -from pathlib import Path -import itertools + from spikeinterface import NumpyRecording from spikeinterface.core import generate_recording from spikeinterface.preprocessing.decimate import DecimateRecording diff --git a/src/spikeinterface/preprocessing/tests/test_depth_order.py b/src/spikeinterface/preprocessing/tests/test_depth_order.py index b0dbc2a8da..3b4ea8f1f2 100644 --- a/src/spikeinterface/preprocessing/tests/test_depth_order.py +++ b/src/spikeinterface/preprocessing/tests/test_depth_order.py @@ -1,8 +1,6 @@ -import pytest - from spikeinterface.core import NumpyRecording -from spikeinterface.preprocessing import DepthOrderRecording, depth_order +from spikeinterface.preprocessing import depth_order import numpy as np diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 4622be1440..f0e22c4afd 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -1,5 +1,7 @@ import pytest import numpy as np +import importlib.util + from spikeinterface import NumpyRecording, get_random_data_chunks from probeinterface import generate_linear_probe @@ -7,14 +9,14 @@ from spikeinterface.core import generate_recording from spikeinterface.preprocessing import detect_bad_channels, highpass_filter -try: - # WARNING : this is not this package https://pypi.org/project/neurodsp/ - # BUT this one https://github.com/int-brain-lab/ibl-neuropixel - # pip install ibl-neuropixel - import neurodsp.voltage - +# WARNING : this is not this package https://pypi.org/project/neurodsp/ +# BUT this one https://github.com/int-brain-lab/ibl-neuropixel +# pip install ibl-neuropixel +# note the check needs to find package first before submodule otherwise the check will fail if the library +# does not exist +if importlib.util.find_spec("neurodsp") is not None and importlib.util.find_spec("neurodsp.voltage") is not None: HAVE_NPIX = True -except: # Catch relevant exception +else: HAVE_NPIX = False @@ -115,6 +117,8 @@ def test_detect_bad_channels_ibl(num_channels): however for testing it is necssary. So before calling the IBL function we need to rescale the traces to Volts. """ + import neurodsp.voltage + # download_path = si.download_dataset(remote_path='spikeglx/Noise4Sam_g0') # recording = se.read_spikeglx(download_path, stream_id="imec0.ap") recording = generate_recording(num_channels=num_channels, durations=[1]) @@ -166,9 +170,9 @@ def test_detect_bad_channels_ibl(num_channels): channel_flags_ibl[:, i] = channel_flags # Take the mode of the chunk estimates as final result. Convert to binary good / bad channel output. - import scipy.stats + from scipy.stats import mode - bad_channel_labels_ibl, _ = scipy.stats.mode(channel_flags_ibl, axis=1, keepdims=False) + bad_channel_labels_ibl, _ = mode(channel_flags_ibl, axis=1, keepdims=False) # Compare channels_labeled_as_good = bad_channel_labels_si == "good" @@ -216,7 +220,7 @@ def reduce_high_freq_power_in_non_noisy_channels(recording, is_noisy, not_noisy) Reduce power in >80% Nyquist for all channels except noisy channels to 20% of original. Return the psd_cutoff in uV^2/Hz that separates the good at noisy channels. """ - import scipy.signal + from scipy.signal import welch for iseg, __ in enumerate(recording._recording_segments): data = recording.get_traces(iseg).T @@ -230,7 +234,7 @@ def reduce_high_freq_power_in_non_noisy_channels(recording, is_noisy, not_noisy) data[not_noisy] = np.fft.ifft(np.fft.ifftshift(D)) # calculate the psd_cutoff (which separates noisy and non-noisy) ad-hoc from the last segment - fscale, psd = scipy.signal.welch(data, fs=recording.get_sampling_frequency()) + fscale, psd = welch(data, fs=recording.get_sampling_frequency()) psd_cutoff = np.mean([np.mean(psd[not_noisy, -50:]), np.mean(psd[is_noisy, -50:])]) return psd_cutoff diff --git a/src/spikeinterface/preprocessing/tests/test_directional_derivative.py b/src/spikeinterface/preprocessing/tests/test_directional_derivative.py index 5f887ae35c..c12602fb12 100644 --- a/src/spikeinterface/preprocessing/tests/test_directional_derivative.py +++ b/src/spikeinterface/preprocessing/tests/test_directional_derivative.py @@ -1,6 +1,6 @@ from spikeinterface.core import NumpyRecording -from spikeinterface.preprocessing import DirectionalDerivativeRecording, directional_derivative +from spikeinterface.preprocessing import directional_derivative import numpy as np diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index 380b345329..aeae825c73 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from spikeinterface.core.generate import generate_recording -from spikeinterface.preprocessing import scale_to_uV, CenterRecording, scale_to_physical_units +from spikeinterface.core.testing_tools import generate_recording +from spikeinterface.preprocessing.preprocessing_classes import scale_to_uV, CenterRecording, scale_to_physical_units def test_scale_to_uV(): diff --git a/src/spikeinterface/sorters/external/combinato.py b/src/spikeinterface/sorters/external/combinato.py index 14bf218cdc..f67a01094e 100644 --- a/src/spikeinterface/sorters/external/combinato.py +++ b/src/spikeinterface/sorters/external/combinato.py @@ -9,7 +9,7 @@ from spikeinterface.sorters.utils import ShellScript from spikeinterface.core import write_to_h5_dataset_format from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.extractors import CombinatoSortingExtractor +from spikeinterface.extractors.extractor_classes import CombinatoSortingExtractor PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index c5f510a631..8d3e3c44a3 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -13,7 +13,7 @@ from spikeinterface.sorters.utils import ShellScript # from spikeinterface.extractors import MaxOneRecordingExtractor -from spikeinterface.extractors import HDSortSortingExtractor +from spikeinterface.extractors.extractor_classes import HDSortSortingExtractor PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index 30535b48ed..70416bb944 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -5,7 +5,7 @@ from spikeinterface.sorters.basesorter import BaseSorter -from spikeinterface.extractors import HerdingspikesSortingExtractor +from spikeinterface.extractors.extractor_classes import HerdingspikesSortingExtractor class HerdingspikesSorter(BaseSorter): diff --git a/src/spikeinterface/sorters/external/ironclust.py b/src/spikeinterface/sorters/external/ironclust.py index 4373cad254..1928918ccc 100644 --- a/src/spikeinterface/sorters/external/ironclust.py +++ b/src/spikeinterface/sorters/external/ironclust.py @@ -8,7 +8,7 @@ from spikeinterface.sorters.utils import ShellScript from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs -from spikeinterface.extractors import MdaRecordingExtractor, MdaSortingExtractor +from spikeinterface.extractors.extractor_classes import MdaRecordingExtractor, MdaSortingExtractor PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 0e185a4051..97886868c1 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -10,7 +10,7 @@ from spikeinterface.sorters.utils import ShellScript, get_matlab_shell_name, get_bash_path from spikeinterface.sorters.basesorter import get_job_kwargs -from spikeinterface.extractors import KiloSortSortingExtractor +from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor from spikeinterface.core import write_binary_recording from spikeinterface.preprocessing.zero_channel_pad import TracePaddedRecording diff --git a/src/spikeinterface/sorters/external/klusta.py b/src/spikeinterface/sorters/external/klusta.py index d96eb74789..b6602ba6e2 100644 --- a/src/spikeinterface/sorters/external/klusta.py +++ b/src/spikeinterface/sorters/external/klusta.py @@ -12,7 +12,7 @@ from probeinterface import write_prb from spikeinterface.core import write_binary_recording -from spikeinterface.extractors import KlustaSortingExtractor +from spikeinterface.extractors.extractor_classes import KlustaSortingExtractor class KlustaSorter(BaseSorter): diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index d0b0f4f118..f79b29eb73 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -7,7 +7,7 @@ from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.sorters.basesorter import BaseSorter from spikeinterface.core.old_api_utils import NewToOldRecording -from spikeinterface.extractors import NpzSortingExtractor, NumpySorting +from spikeinterface.extractors.extractor_classes import NpzSortingExtractor, NumpySorting class Mountainsort4Sorter(BaseSorter): diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index 46fe021ad7..3c20010ab5 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -13,7 +13,7 @@ from spikeinterface.preprocessing import bandpass_filter, whiten from spikeinterface.core.baserecording import BaseRecording from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs -from spikeinterface.extractors import NpzSortingExtractor +from spikeinterface.extractors.extractor_classes import NpzSortingExtractor class Mountainsort5Sorter(BaseSorter): diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 76eb291c79..c5d8dddfb9 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -8,7 +8,7 @@ import numpy as np -from spikeinterface.extractors import KiloSortSortingExtractor +from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor from spikeinterface.core import write_binary_recording from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs diff --git a/src/spikeinterface/sorters/external/spyking_circus.py b/src/spikeinterface/sorters/external/spyking_circus.py index c9236407f1..2df2c97be5 100644 --- a/src/spikeinterface/sorters/external/spyking_circus.py +++ b/src/spikeinterface/sorters/external/spyking_circus.py @@ -10,7 +10,7 @@ from numpy.lib.format import open_memmap -from spikeinterface.extractors import SpykingCircusSortingExtractor +from spikeinterface.extractors.extractor_classes import SpykingCircusSortingExtractor from spikeinterface.sorters.basesorter import BaseSorter from spikeinterface.sorters.utils import ShellScript diff --git a/src/spikeinterface/sorters/external/tridesclous.py b/src/spikeinterface/sorters/external/tridesclous.py index 28596cc5b4..65aadd1d55 100644 --- a/src/spikeinterface/sorters/external/tridesclous.py +++ b/src/spikeinterface/sorters/external/tridesclous.py @@ -6,7 +6,7 @@ import importlib.util from importlib.metadata import version -from spikeinterface.extractors import TridesclousSortingExtractor +from spikeinterface.extractors.extractor_classes import TridesclousSortingExtractor from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from spikeinterface.core import write_binary_recording diff --git a/src/spikeinterface/sorters/external/waveclus.py b/src/spikeinterface/sorters/external/waveclus.py index 90e3e5ef25..42561a2e4a 100644 --- a/src/spikeinterface/sorters/external/waveclus.py +++ b/src/spikeinterface/sorters/external/waveclus.py @@ -13,7 +13,7 @@ from spikeinterface.sorters.utils import ShellScript from spikeinterface.core import write_to_h5_dataset_format -from spikeinterface.extractors import WaveClusSortingExtractor +from spikeinterface.extractors.extractor_classes import WaveClusSortingExtractor from spikeinterface.core.channelslice import ChannelSliceRecording PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/waveclus_snippets.py b/src/spikeinterface/sorters/external/waveclus_snippets.py index 20d45b8c7f..3bd2c98a3c 100644 --- a/src/spikeinterface/sorters/external/waveclus_snippets.py +++ b/src/spikeinterface/sorters/external/waveclus_snippets.py @@ -11,8 +11,8 @@ from spikeinterface.sorters.basesorter import BaseSorter from spikeinterface.sorters.utils import ShellScript -from spikeinterface.extractors import WaveClusSortingExtractor -from spikeinterface.extractors import WaveClusSnippetsExtractor +from spikeinterface.extractors.extractor_classes import WaveClusSortingExtractor +from spikeinterface.extractors.extractor_classes import WaveClusSnippetsExtractor PathType = Union[str, Path] diff --git a/src/spikeinterface/sorters/external/yass.py b/src/spikeinterface/sorters/external/yass.py index 11ac4edd87..aabc8f77b4 100644 --- a/src/spikeinterface/sorters/external/yass.py +++ b/src/spikeinterface/sorters/external/yass.py @@ -11,7 +11,7 @@ from spikeinterface.sorters.utils import ShellScript from spikeinterface.core import write_binary_recording -from spikeinterface.extractors import YassSortingExtractor +from spikeinterface.extractors.extractor_classes import YassSortingExtractor class YassSorter(BaseSorter): diff --git a/src/spikeinterface/sortingcomponents/tests/test_triage.py b/src/spikeinterface/sortingcomponents/tests/test_triage.py index 93ebb564e1..009021c69a 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_triage.py +++ b/src/spikeinterface/sortingcomponents/tests/test_triage.py @@ -3,7 +3,7 @@ from spikeinterface import download_dataset -from spikeinterface.extractors import MEArecRecordingExtractor +from spikeinterface.extractors.neoextractors import MEArecRecordingExtractor from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py index 1823b0f438..6ae694c083 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py +++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_neural_network_denoiser.py @@ -1,9 +1,3 @@ -import numpy as np -import pytest - -from spikeinterface.extractors import MEArecRecordingExtractor -from spikeinterface import download_dataset - from spikeinterface.core.node_pipeline import run_node_pipeline, PeakRetriever, ExtractDenseWaveforms from spikeinterface.sortingcomponents.waveforms.neural_network_denoiser import SingleChannelToyDenoiser