1010
1111if TYPE_CHECKING :
1212 from importlib .metadata import EntryPoint
13-
13+ from zarr . codecs . numcodec import Numcodec
1414 from zarr .abc .codec import (
1515 ArrayArrayCodec ,
1616 ArrayBytesCodec ,
@@ -53,6 +53,10 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
5353 self [qualname ] = cls
5454
5555
56+ __filter_registries : dict [str , Registry [ArrayArrayCodec ]] = defaultdict (Registry )
57+ __serializer_registries : dict [str , Registry [ArrayBytesCodec ]] = defaultdict (Registry )
58+ __compressor_registries : dict [str , Registry [BytesBytesCodec ]] = defaultdict (Registry )
59+
5660__codec_registries : dict [str , Registry [Codec ]] = defaultdict (Registry )
5761__pipeline_registry : Registry [CodecPipeline ] = Registry ()
5862__buffer_registry : Registry [Buffer ] = Registry ()
@@ -117,17 +121,59 @@ def _collect_entrypoints() -> list[Registry[Any]]:
117121def _reload_config () -> None :
118122 config .refresh ()
119123
120-
121124def fully_qualified_name (cls : type ) -> str :
122125 module = cls .__module__
123126 return module + "." + cls .__qualname__
124127
128+ def register_filter (key : str , codec_cls : type [ArrayArrayCodec ]) -> None :
129+ if key not in __filter_registries :
130+ __filter_registries [key ] = Registry ()
131+ __filter_registries [key ].register (codec_cls )
132+
133+ def register_serializer (key : str , codec_cls : type [ArrayBytesCodec ]) -> None :
134+ from zarr .codecs .numcodec import NumcodecsArrayBytesCodec , is_numcodec_cls
135+ if is_numcodec_cls (codec_cls ):
136+ _codec_cls = NumcodecsArrayBytesCodec (_codec = codec_cls )
137+ else :
138+ _codec_cls = codec_cls
139+ if key not in __serializer_registries :
140+ __serializer_registries [key ] = Registry ()
141+ __serializer_registries [key ].register (_codec_cls )
142+
143+ def register_serializer (key : str , codec_cls : type [ArrayBytesCodec ]) -> None :
144+ from zarr .codecs .numcodec import NumcodecsArrayBytesCodec , is_numcodec_cls
145+ if is_numcodec_cls (codec_cls ):
146+ _codec_cls = NumcodecsArrayBytesCodec (_codec = codec_cls )
147+ else :
148+ _codec_cls = codec_cls
149+ if key not in __serializer_registries :
150+ __serializer_registries [key ] = Registry ()
151+ __serializer_registries [key ].register (_codec_cls )
152+
153+ def register_compressor (key : str , codec_cls : type [BytesBytesCodec | Numcodec ]) -> None :
154+ from zarr .codecs .numcodec import NumcodecsBytesBytesCodec , is_numcodec_cls
155+ if is_numcodec_cls (codec_cls ):
156+ _codec_cls = NumcodecsBytesBytesCodec (_codec = codec_cls )
157+ else :
158+ _codec_cls = codec_cls
159+ if key not in __compressor_registries :
160+ __compressor_registries [key ] = Registry ()
161+ __compressor_registries [key ].register (_codec_cls )
125162
126163def register_codec (key : str , codec_cls : type [Codec ]) -> None :
164+ from zarr .abc .codec import ArrayArrayCodec , ArrayBytesCodec
165+ if issubclass (codec_cls , ArrayBytesCodec ):
166+ register_serializer (key , codec_cls )
167+ elif issubclass (codec_cls , ArrayArrayCodec ):
168+ register_filter (key , codec_cls )
169+ else :
170+ register_compressor (key , codec_cls )
171+
172+ """
127173 if key not in __codec_registries:
128174 __codec_registries[key] = Registry()
129175 __codec_registries[key].register(codec_cls)
130-
176+ """
131177
132178def register_pipeline (pipe_cls : type [CodecPipeline ]) -> None :
133179 __pipeline_registry .register (pipe_cls )
@@ -140,6 +186,41 @@ def register_ndbuffer(cls: type[NDBuffer], qualname: str | None = None) -> None:
140186def register_buffer (cls : type [Buffer ], qualname : str | None = None ) -> None :
141187 __buffer_registry .register (cls , qualname )
142188
189+ def get_filter_class (key : str , reload_config : bool = False ) -> type [ArrayArrayCodec ]:
190+ return _get_codec_class (key , __serializer_registries , reload_config = reload_config )
191+
192+ def get_serializer_class (key : str , reload_config : bool = False ) -> type [ArrayBytesCodec ]:
193+ return _get_codec_class (key , __serializer_registries , reload_config = reload_config )
194+
195+ def get_compressor_class (key : str , reload_config : bool = False ) -> type [BytesBytesCodec ]:
196+ return _get_codec_class (key , __compressor_registries , reload_config = reload_config )
197+
198+ def _get_codec_class (key : str , registry : dict [str , Registry [Codec ]], * , reload_config : bool = False ) -> type [Codec ]:
199+ if reload_config :
200+ _reload_config ()
201+
202+ if key in registry :
203+ # logger.debug("Auto loading codec '%s' from entrypoint", codec_id)
204+ registry [key ].lazy_load ()
205+
206+ codec_classes = registry [key ]
207+ if not codec_classes :
208+ raise KeyError (key )
209+
210+ config_entry = config .get ("codecs" , {}).get (key )
211+ if config_entry is None :
212+ if len (codec_classes ) == 1 :
213+ return next (iter (codec_classes .values ()))
214+ warnings .warn (
215+ f"Codec '{ key } ' not configured in config. Selecting any implementation." ,
216+ stacklevel = 2 ,
217+ )
218+ return list (codec_classes .values ())[- 1 ]
219+ selected_codec_cls = codec_classes [config_entry ]
220+
221+ if selected_codec_cls :
222+ return selected_codec_cls
223+ raise KeyError (key )
143224
144225def get_codec_class (key : str , reload_config : bool = False ) -> type [Codec ]:
145226 if reload_config :
@@ -189,7 +270,7 @@ def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec | Numcodec) -> BytesB
189270
190271 result : BytesBytesCodec
191272 if isinstance (data , dict ):
192- result = _resolve_codec (data )
273+ result = get_compressor_class ( data [ "name" ]). from_dict (data )
193274 if not isinstance (result , BytesBytesCodec ):
194275 msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a { type (result )} instead."
195276 raise TypeError (msg )
@@ -202,39 +283,43 @@ def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec | Numcodec) -> BytesB
202283 return result
203284
204285
205- def _parse_array_bytes_codec (data : dict [str , JSON ] | Codec ) -> ArrayBytesCodec :
286+ def _parse_array_bytes_codec (data : dict [str , JSON ] | Codec | Numcodec ) -> ArrayBytesCodec :
206287 """
207288 Normalize the input to a ``ArrayBytesCodec`` instance.
208289 If the input is already a ``ArrayBytesCodec``, it is returned as is. If the input is a dict, it
209290 is converted to a ``ArrayBytesCodec`` instance via the ``_resolve_codec`` function.
210291 """
211292 from zarr .abc .codec import ArrayBytesCodec
212-
293+ from zarr . codecs . numcodec import Numcodec , NumcodecsArrayBytesCodec
213294 if isinstance (data , dict ):
214- result = _resolve_codec (data )
295+ result = get_serializer_class ( data [ "name" ]). from_dict (data )
215296 if not isinstance (result , ArrayBytesCodec ):
216297 msg = f"Expected a dict representation of a ArrayBytesCodec; got a dict representation of a { type (result )} instead."
217298 raise TypeError (msg )
299+ elif isinstance (data , Numcodec ):
300+ return NumcodecsArrayBytesCodec (_codec = data )
218301 else :
219302 if not isinstance (data , ArrayBytesCodec ):
220303 raise TypeError (f"Expected a ArrayBytesCodec. Got { type (data )} instead." )
221304 result = data
222305 return result
223306
224307
225- def _parse_array_array_codec (data : dict [str , JSON ] | Codec ) -> ArrayArrayCodec :
308+ def _parse_array_array_codec (data : dict [str , JSON ] | Codec | Numcodec ) -> ArrayArrayCodec :
226309 """
227310 Normalize the input to a ``ArrayArrayCodec`` instance.
228311 If the input is already a ``ArrayArrayCodec``, it is returned as is. If the input is a dict, it
229312 is converted to a ``ArrayArrayCodec`` instance via the ``_resolve_codec`` function.
230313 """
231314 from zarr .abc .codec import ArrayArrayCodec
232-
315+ from zarr . codecs . numcodec import Numcodec , NumcodecsArrayArrayCodec
233316 if isinstance (data , dict ):
234- result = _resolve_codec (data )
317+ result = get_filter_class ( data [ "name" ]). from_dict (data )
235318 if not isinstance (result , ArrayArrayCodec ):
236319 msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a { type (result )} instead."
237320 raise TypeError (msg )
321+ elif isinstance (data , Numcodec ):
322+ return NumcodecsArrayArrayCodec (_codec = data )
238323 else :
239324 if not isinstance (data , ArrayArrayCodec ):
240325 raise TypeError (f"Expected a ArrayArrayCodec. Got { type (data )} instead." )
0 commit comments