Skip to content

Commit 5d5ddbe

Browse files
authored
Add zarr_class_info attribute to zarr files for unambiguous and extendible zarr read (#4467)
1 parent 6cf41bb commit 5d5ddbe

1 file changed

Lines changed: 18 additions & 9 deletions

File tree

src/spikeinterface/core/zarrextractors.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from probeinterface import ProbeGroup
77

8-
from .base import minimum_spike_dtype
8+
from .base import minimum_spike_dtype, _get_class_from_string
99
from .baserecording import BaseRecording, BaseRecordingSegment
1010
from .basesorting import BaseSorting, SpikeVectorSortingSegment
11-
from .core_tools import define_function_from_class, check_json
11+
from .core_tools import define_function_from_class, check_json, retrieve_importing_provenance
1212
from .job_tools import split_job_kwargs
1313
from .core_tools import is_path_remote
1414

@@ -212,6 +212,7 @@ def write_recording(
212212
recording: BaseRecording, folder_path: str | Path, storage_options: dict | None = None, **kwargs
213213
):
214214
zarr_root = zarr.open(str(folder_path), mode="w", storage_options=storage_options)
215+
zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrRecordingExtractor)
215216
add_recording_to_zarr_group(recording, zarr_root, **kwargs)
216217

217218

@@ -320,6 +321,7 @@ def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options
320321
Write a sorting extractor to zarr format.
321322
"""
322323
zarr_root = zarr.open(str(folder_path), mode="w", storage_options=storage_options)
324+
zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrSortingExtractor)
323325
add_sorting_to_zarr_group(sorting, zarr_root, **kwargs)
324326

325327

@@ -345,15 +347,22 @@ def read_zarr(
345347
extractor : ZarrExtractor
346348
The loaded extractor
347349
"""
348-
# TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
349-
# for the futur SortingAnalyzer we will have this 2 fields!!!
350350
root = super_zarr_open(folder_path, mode="r", storage_options=storage_options)
351-
if "channel_ids" in root.keys():
352-
return read_zarr_recording(folder_path, storage_options=storage_options)
353-
elif "unit_ids" in root.keys():
354-
return read_zarr_sorting(folder_path, storage_options=storage_options)
351+
zarr_class_info = root.attrs.get("zarr_class_info", None)
352+
if zarr_class_info is not None:
353+
class_name = zarr_class_info["class"]
354+
extractor_class = _get_class_from_string(class_name)
355+
return extractor_class(folder_path, storage_options=storage_options)
355356
else:
356-
raise ValueError("Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format")
357+
# For version<0.105.0 zarr files, revert to old way of loading based on the presence of "channel_ids"/"unit_ids"
358+
if "channel_ids" in root.keys():
359+
return read_zarr_recording(folder_path, storage_options=storage_options)
360+
elif "unit_ids" in root.keys():
361+
return read_zarr_sorting(folder_path, storage_options=storage_options)
362+
else:
363+
raise ValueError(
364+
"Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format"
365+
)
357366

358367

359368
### UTILITY FUNCTIONS ###

0 commit comments

Comments
 (0)