|
| 1 | +"""Cast-value array-to-array codec. |
| 2 | +
|
| 3 | +Value-converts array elements to a new data type during encoding, |
| 4 | +and back to the original data type during decoding, with configurable |
| 5 | +rounding, out-of-range handling, and explicit scalar mappings. |
| 6 | +
|
| 7 | +Requires the optional ``cast-value-rs`` package for the actual casting |
| 8 | +logic. Install it with: ``pip install cast-value-rs``. |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +from dataclasses import dataclass, replace |
| 14 | +from typing import TYPE_CHECKING, Literal, cast |
| 15 | + |
| 16 | +import numpy as np |
| 17 | + |
| 18 | +from zarr.abc.codec import ArrayArrayCodec |
| 19 | +from zarr.core.common import JSON, parse_named_configuration |
| 20 | +from zarr.core.dtype import get_data_type_from_json |
| 21 | + |
if TYPE_CHECKING:
    # Imports needed only for static type checking; kept out of the runtime
    # path so annotations stay cheap and import cycles are avoided.
    from collections.abc import Iterable, Mapping
    from typing import Any, NotRequired, Self, TypedDict

    from zarr.core.array_spec import ArraySpec
    from zarr.core.buffer import NDBuffer
    from zarr.core.chunk_grids import ChunkGrid
    from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType

    class ScalarMapJSON(TypedDict):
        """JSON shape of the ``scalar_map`` configuration.

        Each direction is an optional list of (source, target) scalar pairs;
        a direction may be omitted entirely.
        """

        encode: NotRequired[list[tuple[object, object]]]
        decode: NotRequired[list[tuple[object, object]]]

    # Pre-parsed scalar map entry: (source_scalar, target_scalar)
    ScalarMapEntry = tuple[np.integer[Any] | np.floating[Any], np.integer[Any] | np.floating[Any]]
| 37 | + |
# Rounding strategy used when a value has no exact representation in the
# target data type (names follow the cast_value codec spec).
RoundingMode = Literal[
    "nearest-even",
    "towards-zero",
    "towards-positive",
    "towards-negative",
    "nearest-away",
]

# Handling for values outside the target type's representable range;
# ``None`` (no mode configured) means such values raise an error.
OutOfRangeMode = Literal["clamp", "wrap"]
| 47 | + |
| 48 | + |
| 49 | +# --------------------------------------------------------------------------- |
| 50 | +# Scalar-map parsing helpers |
| 51 | +# --------------------------------------------------------------------------- |
| 52 | + |
| 53 | + |
| 54 | +def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None: |
| 55 | + """Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'.""" |
| 56 | + if data is None: |
| 57 | + return None |
| 58 | + raw: dict[str, str] = {} |
| 59 | + pairs: list[tuple[object, object]] = data.get(direction, []) # type: ignore[assignment] |
| 60 | + for src, tgt in pairs: |
| 61 | + raw[str(src)] = str(tgt) |
| 62 | + return raw or None |
| 63 | + |
| 64 | + |
| 65 | +def _parse_map_entries( |
| 66 | + mapping: Mapping[str, str], |
| 67 | + src_dtype: ZDType[TBaseDType, TBaseScalar], |
| 68 | + tgt_dtype: ZDType[TBaseDType, TBaseScalar], |
| 69 | +) -> tuple[ScalarMapEntry, ...]: |
| 70 | + """Pre-parse a scalar map dict into a tuple of (src, tgt) pairs. |
| 71 | +
|
| 72 | + Each entry's source value is deserialized using ``src_dtype`` and its target |
| 73 | + value using ``tgt_dtype``, preserving full precision for both data types. |
| 74 | + """ |
| 75 | + entries: list[ScalarMapEntry] = [ |
| 76 | + ( |
| 77 | + src_dtype.from_json_scalar(src_str, zarr_format=3), # type: ignore[misc] |
| 78 | + tgt_dtype.from_json_scalar(tgt_str, zarr_format=3), |
| 79 | + ) |
| 80 | + for src_str, tgt_str in mapping.items() |
| 81 | + ] |
| 82 | + return tuple(entries) |
| 83 | + |
| 84 | + |
| 85 | +# --------------------------------------------------------------------------- |
| 86 | +# Backend: cast-value-rs (optional) |
| 87 | +# --------------------------------------------------------------------------- |
| 88 | + |
# The Rust backend is an optional dependency. When it is missing, the flag
# below stays False and _do_cast raises an ImportError at use time with an
# install hint, rather than failing at module import.
try:
    from cast_value_rs import cast_array as _rs_cast_array

    _HAS_RUST_BACKEND = True
except ModuleNotFoundError:
    _HAS_RUST_BACKEND = False
| 95 | + |
| 96 | + |
| 97 | +def _dtype_to_str(dtype: np.dtype) -> str: # type: ignore[type-arg] |
| 98 | + return dtype.name |
| 99 | + |
| 100 | + |
| 101 | +def _convert_scalar_map( |
| 102 | + entries: Iterable[ScalarMapEntry] | None, |
| 103 | +) -> list[tuple[int | float, int | float]] | None: |
| 104 | + if entries is None: |
| 105 | + return None |
| 106 | + result: list[tuple[int | float, int | float]] = [] |
| 107 | + for src, tgt in entries: |
| 108 | + src_py: int | float = int(src) if isinstance(src, np.integer) else float(src) |
| 109 | + tgt_py: int | float = int(tgt) if isinstance(tgt, np.integer) else float(tgt) |
| 110 | + result.append((src_py, tgt_py)) |
| 111 | + return result |
| 112 | + |
| 113 | + |
def _cast_array_rs(
    arr: np.ndarray,  # type: ignore[type-arg]
    *,
    target_dtype: np.dtype,  # type: ignore[type-arg]
    rounding: RoundingMode,
    out_of_range: OutOfRangeMode | None,
    scalar_map_entries: Iterable[ScalarMapEntry] | None,
) -> np.ndarray:  # type: ignore[type-arg]
    """Dispatch one cast to the Rust backend, translating arguments to the
    plain-Python forms (dtype name string, int/float scalar pairs) it expects."""
    call_kwargs = {
        "arr": arr,
        "target_dtype": _dtype_to_str(target_dtype),
        "rounding_mode": rounding,
        "out_of_range_mode": out_of_range,
        "scalar_map_entries": _convert_scalar_map(scalar_map_entries),
    }
    return _rs_cast_array(**call_kwargs)  # type: ignore[no-any-return]
| 129 | + |
| 130 | + |
| 131 | +# --------------------------------------------------------------------------- |
| 132 | +# Codec |
| 133 | +# --------------------------------------------------------------------------- |
| 134 | + |
| 135 | + |
@dataclass(frozen=True)
class CastValue(ArrayArrayCodec):
    """Cast-value array-to-array codec.

    Value-converts array elements to a new data type during encoding,
    and back to the original data type during decoding.

    Requires the ``cast-value-rs`` package for the actual casting logic.

    Parameters
    ----------
    data_type : str
        Target zarr v3 data type name (e.g. "uint8", "float32").
    rounding : RoundingMode
        How to round when exact representation is impossible. Default is "nearest-even".
    out_of_range : OutOfRangeMode or None
        What to do when a value is outside the target's range.
        None means error. "clamp" clips to range. "wrap" uses modular arithmetic
        (only valid for integer types).
    scalar_map : dict or None
        Explicit mapping from input scalars to output scalars.

    References
    ----------

    - The `cast_value` codec spec: https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/cast_value
    """

    # Element count is preserved and both dtypes are fixed-size, so the
    # encoded size is computable up front.
    is_fixed_size = True

    dtype: ZDType[TBaseDType, TBaseScalar]
    rounding: RoundingMode
    out_of_range: OutOfRangeMode | None
    scalar_map: ScalarMapJSON | None

    def __init__(
        self,
        *,
        data_type: str | ZDType[TBaseDType, TBaseScalar],
        rounding: RoundingMode = "nearest-even",
        out_of_range: OutOfRangeMode | None = None,
        scalar_map: ScalarMapJSON | None = None,
    ) -> None:
        # Accept either a zarr v3 data type name or an already-resolved ZDType.
        if isinstance(data_type, str):
            zdtype = get_data_type_from_json(data_type, zarr_format=3)
        else:
            zdtype = data_type
        # The dataclass is frozen, so fields must be assigned via
        # object.__setattr__ rather than normal attribute assignment.
        object.__setattr__(self, "dtype", zdtype)
        object.__setattr__(self, "rounding", rounding)
        object.__setattr__(self, "out_of_range", out_of_range)
        object.__setattr__(self, "scalar_map", scalar_map)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build a codec instance from its JSON metadata representation."""
        _, configuration_parsed = parse_named_configuration(
            data, "cast_value", require_configuration=True
        )
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        """Serialize to JSON metadata, omitting options left at their defaults."""
        config: dict[str, JSON] = {"data_type": cast("JSON", self.dtype.to_json(zarr_format=3))}
        if self.rounding != "nearest-even":
            config["rounding"] = self.rounding
        if self.out_of_range is not None:
            config["out_of_range"] = self.out_of_range
        if self.scalar_map is not None:
            config["scalar_map"] = cast("JSON", self.scalar_map)
        return {"name": "cast_value", "configuration": config}

    def validate(
        self,
        *,
        shape: tuple[int, ...],
        dtype: ZDType[TBaseDType, TBaseScalar],
        chunk_grid: ChunkGrid,
    ) -> None:
        """Reject unsupported source/target dtype combinations.

        Raises
        ------
        ValueError
            If either dtype is not an integer or floating-point type, or if
            ``out_of_range='wrap'`` is combined with a non-integer target.
        """
        source_native = dtype.to_native_dtype()
        target_native = self.dtype.to_native_dtype()
        for label, dt in [("source", source_native), ("target", target_native)]:
            if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating):
                raise ValueError(
                    f"The cast_value codec only supports integer and floating-point data types. "
                    f"Got {label} dtype {dt}."
                )
        if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer):
            raise ValueError("out_of_range='wrap' is only valid for integer target types.")

    def _scalar_entries(
        self,
        direction: Literal["encode", "decode"],
        src_dtype: ZDType[TBaseDType, TBaseScalar],
        tgt_dtype: ZDType[TBaseDType, TBaseScalar],
    ) -> tuple[ScalarMapEntry, ...] | None:
        """Pre-parsed scalar-map entries for *direction*, or None when absent.

        Shared by encoding, decoding, and fill-value resolution so the
        extract-then-parse logic lives in one place.
        """
        raw = _extract_raw_map(self.scalar_map, direction)
        if raw is None:
            return None
        return _parse_map_entries(raw, src_dtype, tgt_dtype)

    def _do_cast(
        self,
        arr: np.ndarray,  # type: ignore[type-arg]
        *,
        target_dtype: np.dtype,  # type: ignore[type-arg]
        scalar_map_entries: Iterable[ScalarMapEntry] | None,
    ) -> np.ndarray:  # type: ignore[type-arg]
        """Cast *arr* to *target_dtype* via the Rust backend.

        Raises
        ------
        ImportError
            If the optional ``cast-value-rs`` package is not installed.
        """
        if not _HAS_RUST_BACKEND:
            raise ImportError(
                "The cast_value codec requires the 'cast-value-rs' package. "
                "Install it with: pip install cast-value-rs"
            )
        return _cast_array_rs(
            arr,
            target_dtype=target_dtype,
            rounding=self.rounding,
            out_of_range=self.out_of_range,
            scalar_map_entries=scalar_map_entries,
        )

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """
        Update the fill value of the output spec by applying casting procedure.
        """
        target_native = self.dtype.to_native_dtype()
        source_native = chunk_spec.dtype.to_native_dtype()

        # Cast the fill value through the exact encode path (including any
        # encode scalar map) so it stays consistent with encoded chunk data.
        fill_arr = np.array([chunk_spec.fill_value], dtype=source_native)
        encode_entries = self._scalar_entries("encode", chunk_spec.dtype, self.dtype)
        new_fill_arr = self._do_cast(
            fill_arr, target_dtype=target_native, scalar_map_entries=encode_entries
        )
        new_fill = target_native.type(new_fill_arr[0])

        return replace(chunk_spec, dtype=self.dtype, fill_value=new_fill)

    def _encode_sync(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer | None:
        """Synchronously cast a chunk from the source dtype to the target dtype."""
        arr = chunk_array.as_ndarray_like()
        target_native = self.dtype.to_native_dtype()

        encode_entries = self._scalar_entries("encode", chunk_spec.dtype, self.dtype)
        result = self._do_cast(
            np.asarray(arr), target_dtype=target_native, scalar_map_entries=encode_entries
        )
        return chunk_array.__class__.from_ndarray_like(result)

    async def _encode_single(
        self,
        chunk_data: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer | None:
        # The cast itself is synchronous CPU work; this is just the async entry point.
        return self._encode_sync(chunk_data, chunk_spec)

    def _decode_sync(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        """Synchronously cast a chunk back from the target dtype to the source dtype."""
        arr = chunk_array.as_ndarray_like()
        # Decoding targets the array's declared (source) dtype.
        target_native = chunk_spec.dtype.to_native_dtype()

        decode_entries = self._scalar_entries("decode", self.dtype, chunk_spec.dtype)
        result = self._do_cast(
            np.asarray(arr), target_dtype=target_native, scalar_map_entries=decode_entries
        )
        return chunk_array.__class__.from_ndarray_like(result)

    async def _decode_single(
        self,
        chunk_data: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        # The cast itself is synchronous CPU work; this is just the async entry point.
        return self._decode_sync(chunk_data, chunk_spec)

    def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
        """Return the byte length after casting: same element count, target itemsize.

        Raises
        ------
        ValueError
            If either dtype reports itemsize 0 (variable-size types are unsupported).
        """
        source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize
        target_itemsize = self.dtype.to_native_dtype().itemsize
        if source_itemsize == 0 or target_itemsize == 0:
            raise ValueError(
                "cast_value codec requires fixed-size data types. "
                f"Got source itemsize={source_itemsize}, target itemsize={target_itemsize}."
            )
        num_elements = input_byte_length // source_itemsize
        return num_elements * target_itemsize
0 commit comments