Skip to content

Commit a50e316

Browse files
committed
feat: add cast_value and scale_offset codecs
Defines two new codecs that together provide a v3-native replacement for the existing `numcodecs.fixedscaleoffset` codec. The `cast_value` codec requires an optional dependency on the `cast-value-rs` package.
1 parent 8f14d67 commit a50e316

File tree

6 files changed

+981
-0
lines changed

6 files changed

+981
-0
lines changed

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ remote = [
6969
gpu = [
7070
"cupy-cuda12x",
7171
]
72+
cast-value-rs = ["cast-value-rs"]
7273
cli = ["typer"]
7374
optional = ["universal-pathlib"]
7475

@@ -190,6 +191,16 @@ run-benchmark = "pytest --benchmark-enable tests/benchmarks"
190191
serve-coverage-html = "python -m http.server -d htmlcov 8000"
191192
list-env = "pip list"
192193

194+
[tool.hatch.envs.cast-value]
195+
template = "test"
196+
features = ["cast-value-rs"]
197+
198+
[[tool.hatch.envs.cast-value.matrix]]
199+
python = ["3.12"]
200+
201+
[tool.hatch.envs.cast-value.scripts]
202+
run = "pytest tests/test_codecs/test_cast_value.py {args:}"
203+
193204
[tool.hatch.envs.gputest]
194205
template = "test"
195206
extra-dependencies = [

src/zarr/codecs/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
44
from zarr.codecs.bytes import BytesCodec, Endian
5+
from zarr.codecs.cast_value import CastValue
56
from zarr.codecs.crc32c_ import Crc32cCodec
67
from zarr.codecs.gzip import GzipCodec
78
from zarr.codecs.numcodecs import (
@@ -27,6 +28,7 @@
2728
Zlib,
2829
Zstd,
2930
)
31+
from zarr.codecs.scale_offset import ScaleOffset
3032
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
3133
from zarr.codecs.transpose import TransposeCodec
3234
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
@@ -38,9 +40,11 @@
3840
"BloscCodec",
3941
"BloscShuffle",
4042
"BytesCodec",
43+
"CastValue",
4144
"Crc32cCodec",
4245
"Endian",
4346
"GzipCodec",
47+
"ScaleOffset",
4448
"ShardingCodec",
4549
"ShardingCodecIndexLocation",
4650
"TransposeCodec",
@@ -50,12 +54,14 @@
5054
]
5155

5256
register_codec("blosc", BloscCodec)
57+
register_codec("cast_value", CastValue)
5358
register_codec("bytes", BytesCodec)
5459

5560
# compatibility with earlier versions of ZEP1
5661
register_codec("endian", BytesCodec)
5762
register_codec("crc32c", Crc32cCodec)
5863
register_codec("gzip", GzipCodec)
64+
register_codec("scale_offset", ScaleOffset)
5965
register_codec("sharding_indexed", ShardingCodec)
6066
register_codec("zstd", ZstdCodec)
6167
register_codec("vlen-utf8", VLenUTF8Codec)

src/zarr/codecs/cast_value.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""Cast-value array-to-array codec.
2+
3+
Value-converts array elements to a new data type during encoding,
4+
and back to the original data type during decoding, with configurable
5+
rounding, out-of-range handling, and explicit scalar mappings.
6+
7+
Requires the optional ``cast-value-rs`` package for the actual casting
8+
logic. Install it with: ``pip install cast-value-rs``.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from dataclasses import dataclass, replace
14+
from typing import TYPE_CHECKING, Literal, cast
15+
16+
import numpy as np
17+
18+
from zarr.abc.codec import ArrayArrayCodec
19+
from zarr.core.common import JSON, parse_named_configuration
20+
from zarr.core.dtype import get_data_type_from_json
21+
22+
if TYPE_CHECKING:
23+
from collections.abc import Iterable, Mapping
24+
from typing import Any, NotRequired, Self, TypedDict
25+
26+
from zarr.core.array_spec import ArraySpec
27+
from zarr.core.buffer import NDBuffer
28+
from zarr.core.chunk_grids import ChunkGrid
29+
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
30+
31+
class ScalarMapJSON(TypedDict):
32+
encode: NotRequired[list[tuple[object, object]]]
33+
decode: NotRequired[list[tuple[object, object]]]
34+
35+
# Pre-parsed scalar map entry: (source_scalar, target_scalar)
36+
ScalarMapEntry = tuple[np.integer[Any] | np.floating[Any], np.integer[Any] | np.floating[Any]]
37+
38+
# Names of the rounding strategies that the codec's ``rounding`` option accepts.
RoundingMode = Literal[
    "nearest-even",
    "towards-zero",
    "towards-positive",
    "towards-negative",
    "nearest-away",
]

# Names of the strategies that the codec's ``out_of_range`` option accepts;
# the option itself may additionally be ``None``.
OutOfRangeMode = Literal["clamp", "wrap"]
47+
48+
49+
# ---------------------------------------------------------------------------
50+
# Scalar-map parsing helpers
51+
# ---------------------------------------------------------------------------
52+
53+
54+
def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None:
    """Return one direction of a scalar map as a ``str -> str`` dictionary.

    Parameters
    ----------
    data
        The scalar-map JSON object, or ``None`` when no map was configured.
    direction
        Key to read from ``data`` (the callers in this module pass
        ``"encode"`` or ``"decode"``).

    Returns
    -------
    dict[str, str] or None
        Every ``(source, target)`` pair found under ``direction`` with both
        members passed through ``str``. ``None`` is returned when ``data`` is
        ``None``, when the key is absent, or when it holds no pairs. If the
        same stringified source appears more than once, the last pair wins.
    """
    if data is None:
        return None
    pairs: list[tuple[object, object]] = data.get(direction, [])  # type: ignore[assignment]
    stringified = {str(source): str(target) for source, target in pairs}
    # An empty mapping is reported as "no map" so callers can use a plain truth test.
    return stringified or None
63+
64+
65+
def _parse_map_entries(
    mapping: Mapping[str, str],
    src_dtype: ZDType[TBaseDType, TBaseScalar],
    tgt_dtype: ZDType[TBaseDType, TBaseScalar],
) -> tuple[ScalarMapEntry, ...]:
    """Deserialize a raw scalar map into a tuple of ``(source, target)`` scalars.

    Parameters
    ----------
    mapping
        Raw map whose keys are source values and whose values are target values.
    src_dtype
        Data type whose ``from_json_scalar`` deserializes each key.
    tgt_dtype
        Data type whose ``from_json_scalar`` deserializes each value.

    Returns
    -------
    tuple
        One ``(source_scalar, target_scalar)`` pair per item of ``mapping``,
        in the mapping's iteration order. Both scalars are read with
        ``zarr_format=3``.
    """
    parsed: list[ScalarMapEntry] = []
    for raw_source, raw_target in mapping.items():
        source_scalar = src_dtype.from_json_scalar(raw_source, zarr_format=3)
        target_scalar = tgt_dtype.from_json_scalar(raw_target, zarr_format=3)
        parsed.append((source_scalar, target_scalar))  # type: ignore[arg-type]
    return tuple(parsed)
83+
84+
85+
# ---------------------------------------------------------------------------
86+
# Backend: cast-value-rs (optional)
87+
# ---------------------------------------------------------------------------
88+
89+
# The casting backend is optional: record whether it could be imported so the
# codec can raise a helpful error at use time instead of failing at import time.
try:
    from cast_value_rs import cast_array as _rs_cast_array
except ModuleNotFoundError:
    _HAS_RUST_BACKEND = False
else:
    _HAS_RUST_BACKEND = True
95+
96+
97+
def _dtype_to_str(dtype: np.dtype) -> str:  # type: ignore[type-arg]
    """Return the ``name`` attribute of a NumPy dtype (e.g. ``"uint8"``)."""
    dtype_name: str = dtype.name
    return dtype_name
99+
100+
101+
def _convert_scalar_map(
    entries: Iterable[ScalarMapEntry] | None,
) -> list[tuple[int | float, int | float]] | None:
    """Convert NumPy scalar pairs into pairs of built-in ``int`` / ``float``.

    Parameters
    ----------
    entries
        Iterable of ``(source, target)`` NumPy scalars, or ``None``.

    Returns
    -------
    list or None
        ``None`` when ``entries`` is ``None``; otherwise a list with one pair
        per entry, where NumPy integers become ``int`` and every other value
        is passed through ``float``.
    """
    if entries is None:
        return None

    def _as_builtin(value: np.integer[Any] | np.floating[Any]) -> int | float:
        # Integers keep exact integer semantics; anything else goes through float.
        return int(value) if isinstance(value, np.integer) else float(value)

    return [(_as_builtin(source), _as_builtin(target)) for source, target in entries]
112+
113+
114+
def _cast_array_rs(
    arr: np.ndarray,  # type: ignore[type-arg]
    *,
    target_dtype: np.dtype,  # type: ignore[type-arg]
    rounding: RoundingMode,
    out_of_range: OutOfRangeMode | None,
    scalar_map_entries: Iterable[ScalarMapEntry] | None,
) -> np.ndarray:  # type: ignore[type-arg]
    """Forward a cast request to ``cast_value_rs.cast_array``.

    The target dtype is passed to the backend by name and the scalar map is
    passed as pairs of built-in Python numbers; the remaining arguments are
    forwarded unchanged.

    Parameters
    ----------
    arr
        Array to cast.
    target_dtype
        NumPy dtype to cast to.
    rounding
        Rounding mode, forwarded as ``rounding_mode``.
    out_of_range
        Out-of-range mode (or ``None``), forwarded as ``out_of_range_mode``.
    scalar_map_entries
        Optional ``(source, target)`` scalar pairs.

    Returns
    -------
    numpy.ndarray
        Whatever the backend returns for the request.
    """
    # Resolve the backend first so a missing optional import surfaces before
    # any argument conversion work is done.
    backend = _rs_cast_array
    dtype_name = _dtype_to_str(target_dtype)
    builtin_entries = _convert_scalar_map(scalar_map_entries)
    return backend(  # type: ignore[no-any-return]
        arr=arr,
        target_dtype=dtype_name,
        rounding_mode=rounding,
        out_of_range_mode=out_of_range,
        scalar_map_entries=builtin_entries,
    )
129+
130+
131+
# ---------------------------------------------------------------------------
132+
# Codec
133+
# ---------------------------------------------------------------------------
134+
135+
136+
@dataclass(frozen=True)
class CastValue(ArrayArrayCodec):
    """Cast-value array-to-array codec.

    Encoding value-converts array elements to ``data_type``; decoding converts
    them back to the array's original data type.

    Requires the ``cast-value-rs`` package for the actual casting logic.

    Parameters
    ----------
    data_type : str
        Target zarr v3 data type name (e.g. "uint8", "float32").
    rounding : RoundingMode
        How to round when exact representation is impossible. Default is "nearest-even".
    out_of_range : OutOfRangeMode or None
        What to do when a value is outside the target's range.
        None means error. "clamp" clips to range. "wrap" uses modular arithmetic
        (only valid for integer types).
    scalar_map : dict or None
        Explicit mapping from input scalars to output scalars.

    References
    ----------

    - The `cast_value` codec spec: https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/cast_value
    """

    is_fixed_size = True

    dtype: ZDType[TBaseDType, TBaseScalar]
    rounding: RoundingMode
    out_of_range: OutOfRangeMode | None
    scalar_map: ScalarMapJSON | None

    def __init__(
        self,
        *,
        data_type: str | ZDType[TBaseDType, TBaseScalar],
        rounding: RoundingMode = "nearest-even",
        out_of_range: OutOfRangeMode | None = None,
        scalar_map: ScalarMapJSON | None = None,
    ) -> None:
        """Store the codec options, resolving a string ``data_type`` to a ZDType."""
        zdtype = (
            get_data_type_from_json(data_type, zarr_format=3)
            if isinstance(data_type, str)
            else data_type
        )
        # The dataclass is frozen, so attributes are set through ``object``.
        field_values = (
            ("dtype", zdtype),
            ("rounding", rounding),
            ("out_of_range", out_of_range),
            ("scalar_map", scalar_map),
        )
        for field_name, field_value in field_values:
            object.__setattr__(self, field_name, field_value)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build a codec from its ``{"name": "cast_value", "configuration": ...}`` form."""
        _name, configuration = parse_named_configuration(
            data, "cast_value", require_configuration=True
        )
        return cls(**configuration)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        """Serialize the codec, omitting options that are at their default value."""
        configuration: dict[str, JSON] = {
            "data_type": cast("JSON", self.dtype.to_json(zarr_format=3))
        }
        if self.rounding != "nearest-even":
            configuration["rounding"] = self.rounding
        if self.out_of_range is not None:
            configuration["out_of_range"] = self.out_of_range
        if self.scalar_map is not None:
            configuration["scalar_map"] = cast("JSON", self.scalar_map)
        return {"name": "cast_value", "configuration": configuration}

    def validate(
        self,
        *,
        shape: tuple[int, ...],
        dtype: ZDType[TBaseDType, TBaseScalar],
        chunk_grid: ChunkGrid,
    ) -> None:
        """Check that the source and target data types are supported.

        Raises
        ------
        ValueError
            If either native dtype is neither integer nor floating-point, or
            if ``out_of_range`` is ``"wrap"`` and the target is not an integer type.
        """
        native_by_role = {
            "source": dtype.to_native_dtype(),
            "target": self.dtype.to_native_dtype(),
        }
        for label, dt in native_by_role.items():
            is_numeric = np.issubdtype(dt, np.integer) or np.issubdtype(dt, np.floating)
            if not is_numeric:
                raise ValueError(
                    f"The cast_value codec only supports integer and floating-point data types. "
                    f"Got {label} dtype {dt}."
                )
        wraps = self.out_of_range == "wrap"
        if wraps and not np.issubdtype(native_by_role["target"], np.integer):
            raise ValueError("out_of_range='wrap' is only valid for integer target types.")

    def _do_cast(
        self,
        arr: np.ndarray,  # type: ignore[type-arg]
        *,
        target_dtype: np.dtype,  # type: ignore[type-arg]
        scalar_map_entries: Iterable[ScalarMapEntry] | None,
    ) -> np.ndarray:  # type: ignore[type-arg]
        """Cast ``arr`` to ``target_dtype`` using this codec's rounding/range options.

        Raises
        ------
        ImportError
            If the optional ``cast-value-rs`` backend was not importable.
        """
        if _HAS_RUST_BACKEND:
            return _cast_array_rs(
                arr,
                target_dtype=target_dtype,
                rounding=self.rounding,
                out_of_range=self.out_of_range,
                scalar_map_entries=scalar_map_entries,
            )
        raise ImportError(
            "The cast_value codec requires the 'cast-value-rs' package. "
            "Install it with: pip install cast-value-rs"
        )

    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
        """
        Return ``chunk_spec`` with the target data type and a cast fill value.

        The fill value is pushed through the same encode-direction cast as the
        chunk data, including the ``encode`` scalar map when one is configured.
        """
        encoded_zdtype = self.dtype
        encoded_native = encoded_zdtype.to_native_dtype()
        decoded_native = chunk_spec.dtype.to_native_dtype()

        # Wrap the scalar fill value in a 1-element array so it can be cast
        # by the array-based backend.
        fill_as_array = np.array([chunk_spec.fill_value], dtype=decoded_native)

        raw_map = _extract_raw_map(self.scalar_map, "encode")
        if raw_map:
            map_entries = _parse_map_entries(raw_map, chunk_spec.dtype, self.dtype)
        else:
            map_entries = None

        cast_fill_array = self._do_cast(
            fill_as_array, target_dtype=encoded_native, scalar_map_entries=map_entries
        )
        cast_fill = encoded_native.type(cast_fill_array[0])

        return replace(chunk_spec, dtype=encoded_zdtype, fill_value=cast_fill)

    def _encode_sync(
        self,
        chunk_array: NDBuffer,
        _chunk_spec: ArraySpec,
    ) -> NDBuffer | None:
        """Cast a chunk from the array's data type to the codec's target data type."""
        source_data = chunk_array.as_ndarray_like()
        encoded_native = self.dtype.to_native_dtype()

        raw_map = _extract_raw_map(self.scalar_map, "encode")
        if raw_map:
            map_entries = _parse_map_entries(raw_map, _chunk_spec.dtype, self.dtype)
        else:
            map_entries = None

        encoded = self._do_cast(
            np.asarray(source_data), target_dtype=encoded_native, scalar_map_entries=map_entries
        )
        return chunk_array.__class__.from_ndarray_like(encoded)

    async def _encode_single(
        self,
        chunk_data: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer | None:
        """Async entry point for encoding; delegates to ``_encode_sync``."""
        return self._encode_sync(chunk_data, chunk_spec)

    def _decode_sync(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        """Cast a chunk from the codec's data type back to ``chunk_spec.dtype``."""
        stored_data = chunk_array.as_ndarray_like()
        decoded_native = chunk_spec.dtype.to_native_dtype()

        raw_map = _extract_raw_map(self.scalar_map, "decode")
        if raw_map:
            # Decoding reverses the roles: the codec's dtype is the source.
            map_entries = _parse_map_entries(raw_map, self.dtype, chunk_spec.dtype)
        else:
            map_entries = None

        decoded = self._do_cast(
            np.asarray(stored_data), target_dtype=decoded_native, scalar_map_entries=map_entries
        )
        return chunk_array.__class__.from_ndarray_like(decoded)

    async def _decode_single(
        self,
        chunk_data: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        """Async entry point for decoding; delegates to ``_decode_sync``."""
        return self._decode_sync(chunk_data, chunk_spec)

    def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
        """Return the encoded byte length for ``input_byte_length`` bytes of input.

        The element count is ``input_byte_length`` floor-divided by the source
        item size, and the result is that count times the target item size.

        Raises
        ------
        ValueError
            If the source or target native dtype reports an item size of 0.
        """
        source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize
        target_itemsize = self.dtype.to_native_dtype().itemsize
        if 0 in (source_itemsize, target_itemsize):
            raise ValueError(
                "cast_value codec requires fixed-size data types. "
                f"Got source itemsize={source_itemsize}, target itemsize={target_itemsize}."
            )
        element_count = input_byte_length // source_itemsize
        return element_count * target_itemsize

0 commit comments

Comments
 (0)