55
66from probeinterface import ProbeGroup
77
8- from .base import minimum_spike_dtype
8+ from .base import minimum_spike_dtype , _get_class_from_string
99from .baserecording import BaseRecording , BaseRecordingSegment
1010from .basesorting import BaseSorting , SpikeVectorSortingSegment
11- from .core_tools import define_function_from_class , check_json
11+ from .core_tools import define_function_from_class , check_json , retrieve_importing_provenance
1212from .job_tools import split_job_kwargs
1313from .core_tools import is_path_remote
1414
@@ -212,6 +212,7 @@ def write_recording(
212212 recording : BaseRecording , folder_path : str | Path , storage_options : dict | None = None , ** kwargs
213213 ):
214214 zarr_root = zarr .open (str (folder_path ), mode = "w" , storage_options = storage_options )
215+ zarr_root .attrs ["zarr_class_info" ] = retrieve_importing_provenance (ZarrRecordingExtractor )
215216 add_recording_to_zarr_group (recording , zarr_root , ** kwargs )
216217
217218
@@ -320,6 +321,7 @@ def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options
320321 Write a sorting extractor to zarr format.
321322 """
322323 zarr_root = zarr .open (str (folder_path ), mode = "w" , storage_options = storage_options )
324+ zarr_root .attrs ["zarr_class_info" ] = retrieve_importing_provenance (ZarrSortingExtractor )
323325 add_sorting_to_zarr_group (sorting , zarr_root , ** kwargs )
324326
325327
@@ -345,15 +347,22 @@ def read_zarr(
345347 extractor : ZarrExtractor
346348 The loaded extractor
347349 """
348- # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
349- # for the futur SortingAnalyzer we will have this 2 fields!!!
350350 root = super_zarr_open (folder_path , mode = "r" , storage_options = storage_options )
351- if "channel_ids" in root .keys ():
352- return read_zarr_recording (folder_path , storage_options = storage_options )
353- elif "unit_ids" in root .keys ():
354- return read_zarr_sorting (folder_path , storage_options = storage_options )
351+ zarr_class_info = root .attrs .get ("zarr_class_info" , None )
352+ if zarr_class_info is not None :
353+ class_name = zarr_class_info ["class" ]
354+ extractor_class = _get_class_from_string (class_name )
355+ return extractor_class (folder_path , storage_options = storage_options )
355356 else :
356- raise ValueError ("Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format" )
357+ # For version<0.105.0 zarr files, revert to old way of loading based on the presence of "channel_ids"/"unit_ids"
358+ if "channel_ids" in root .keys ():
359+ return read_zarr_recording (folder_path , storage_options = storage_options )
360+ elif "unit_ids" in root .keys ():
361+ return read_zarr_sorting (folder_path , storage_options = storage_options )
362+ else :
363+ raise ValueError (
364+ "Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format"
365+ )
357366
358367
359368### UTILITY FUNCTIONS ###
0 commit comments