|
| 1 | +""" |
| 2 | +Utilities for interfacing with the numcodecs library. |
| 3 | +""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import asyncio |
| 8 | +from collections.abc import Mapping |
| 9 | +from dataclasses import dataclass |
| 10 | +from typing import TYPE_CHECKING, Literal, Self, overload |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +from typing_extensions import Protocol, runtime_checkable |
| 14 | + |
| 15 | +from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, CodecConfig_V2 |
| 16 | +from zarr.core.array_spec import ArraySpec |
| 17 | +from zarr.core.buffer.core import Buffer, BufferPrototype, NDArrayLike, NDBuffer |
| 18 | +from zarr.core.buffer.cpu import as_numpy_array_wrapper |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + from zarr.core.array_spec import ArraySpec |
| 22 | + from zarr.core.common import BaseConfig, NamedConfig, ZarrFormat |
| 23 | + |
| 24 | +BufferOrNDArray = Buffer | np.ndarray[tuple[int, ...], np.dtype[np.generic]] | NDArrayLike |
| 25 | + |
| 26 | + |
| 27 | +def resolve_numcodec(config: CodecConfig_V2[str]) -> Numcodec: |
| 28 | + import numcodecs |
| 29 | + |
| 30 | + return numcodecs.get_codec(config) # type: ignore[no-any-return] |
| 31 | + |
| 32 | + |
| 33 | +@runtime_checkable |
| 34 | +class Numcodec(Protocol): |
| 35 | + """ |
| 36 | + A protocol that models the ``numcodecs.abc.Codec`` interface. |
| 37 | + """ |
| 38 | + |
| 39 | + codec_id: str |
| 40 | + |
| 41 | + def encode(self, buf: BufferOrNDArray) -> BufferOrNDArray: ... |
| 42 | + |
| 43 | + def decode( |
| 44 | + self, buf: BufferOrNDArray, out: BufferOrNDArray | None = None |
| 45 | + ) -> BufferOrNDArray: ... |
| 46 | + |
| 47 | + def get_config(self) -> CodecConfig_V2[str]: ... |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def from_config(cls, config: CodecConfig_V2[str]) -> Self: ... |
| 51 | + |
| 52 | + |
| 53 | +@dataclass(frozen=True, kw_only=True) |
| 54 | +class NumcodecsAdapter: |
| 55 | + _codec: Numcodec |
| 56 | + |
| 57 | + @overload |
| 58 | + def to_json(self, zarr_format: Literal[2]) -> CodecConfig_V2[str]: ... |
| 59 | + @overload |
| 60 | + def to_json(self, zarr_format: Literal[3]) -> NamedConfig[str, BaseConfig]: ... |
| 61 | + |
| 62 | + def to_json( |
| 63 | + self, zarr_format: ZarrFormat |
| 64 | + ) -> CodecConfig_V2[str] | NamedConfig[str, BaseConfig]: |
| 65 | + if zarr_format == 2: |
| 66 | + return self._codec.get_config() |
| 67 | + elif zarr_format == 3: |
| 68 | + config = self._codec.get_config() |
| 69 | + config_no_id = {k: v for k, v in config.items() if k != "id"} |
| 70 | + return {"name": config["id"], "configuration": config_no_id} |
| 71 | + raise ValueError(f"Unsupported zarr format: {zarr_format}") # pragma: no cover |
| 72 | + |
| 73 | + @classmethod |
| 74 | + def _from_json_v2(cls, data: Mapping[str, object]) -> Self: |
| 75 | + return cls(_codec=resolve_numcodec(data)) # type: ignore[arg-type] |
| 76 | + |
| 77 | + @classmethod |
| 78 | + def _from_json_v3(cls, data: Mapping[str, object]) -> Self: |
| 79 | + raise NotImplementedError( |
| 80 | + "This class does not support creating instances from JSON data for Zarr format 3." |
| 81 | + ) |
| 82 | + |
| 83 | + def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: |
| 84 | + raise NotImplementedError |
| 85 | + |
| 86 | + |
| 87 | +class NumcodecsBytesBytesCodec(NumcodecsAdapter, BytesBytesCodec): |
| 88 | + async def _decode_single(self, chunk_data: Buffer, chunk_spec: ArraySpec) -> Buffer: |
| 89 | + return await asyncio.to_thread( |
| 90 | + as_numpy_array_wrapper, |
| 91 | + self._codec.decode, |
| 92 | + chunk_data, |
| 93 | + chunk_spec.prototype, |
| 94 | + ) |
| 95 | + |
| 96 | + def _encode(self, chunk_bytes: Buffer, prototype: BufferPrototype) -> Buffer: |
| 97 | + encoded = self._codec.encode(chunk_bytes.as_array_like()) |
| 98 | + if isinstance(encoded, np.ndarray): # Required for checksum codecs |
| 99 | + return prototype.buffer.from_bytes(encoded.tobytes()) |
| 100 | + return prototype.buffer.from_bytes(encoded) |
| 101 | + |
| 102 | + async def _encode_single(self, chunk_data: Buffer, chunk_spec: ArraySpec) -> Buffer: |
| 103 | + return await asyncio.to_thread(self._encode, chunk_data, chunk_spec.prototype) |
| 104 | + |
| 105 | + |
| 106 | +@dataclass(kw_only=True, frozen=True) |
| 107 | +class NumcodecsArrayCodec(NumcodecsAdapter, ArrayArrayCodec): |
| 108 | + async def _decode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer: |
| 109 | + chunk_ndarray = chunk_data.as_ndarray_like() |
| 110 | + out = await asyncio.to_thread(self._codec.decode, chunk_ndarray) |
| 111 | + return chunk_spec.prototype.nd_buffer.from_ndarray_like(out.reshape(chunk_spec.shape)) # type: ignore[union-attr] |
| 112 | + |
| 113 | + async def _encode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer: |
| 114 | + chunk_ndarray = chunk_data.as_ndarray_like() |
| 115 | + out = await asyncio.to_thread(self._codec.encode, chunk_ndarray) |
| 116 | + return chunk_spec.prototype.nd_buffer.from_ndarray_like(out) # type: ignore[arg-type] |
| 117 | + |
| 118 | + |
| 119 | +@dataclass(kw_only=True, frozen=True) |
| 120 | +class NumcodecsArrayBytesCodec(NumcodecsAdapter, ArrayBytesCodec): |
| 121 | + async def _decode_single(self, chunk_data: Buffer, chunk_spec: ArraySpec) -> NDBuffer: |
| 122 | + chunk_bytes = chunk_data.to_bytes() |
| 123 | + out = await asyncio.to_thread(self._codec.decode, chunk_bytes) |
| 124 | + return chunk_spec.prototype.nd_buffer.from_ndarray_like(out.reshape(chunk_spec.shape)) |
| 125 | + |
| 126 | + async def _encode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> Buffer: |
| 127 | + chunk_ndarray = chunk_data.as_ndarray_like() |
| 128 | + out = await asyncio.to_thread(self._codec.encode, chunk_ndarray) |
| 129 | + return chunk_spec.prototype.buffer.from_bytes(out) |
0 commit comments