Commit b123813

fix: make encode / decode stricter about dtypes
1 parent 2e8d644 commit b123813

2 files changed

Lines changed: 481 additions & 41 deletions

src/zarr/codecs/scale_offset.py

Lines changed: 288 additions & 36 deletions
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, replace
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
+import numpy.typing as npt
 
 from zarr.abc.codec import ArrayArrayCodec
 from zarr.core.common import JSON, parse_named_configuration
@@ -17,14 +18,278 @@
     from zarr.core.metadata.v3 import ChunkGridMetadata
 
 
+_WIDE_INT = np.dtype(np.int64)
+
+
+def _encode_fits_natively(dtype: np.dtype[Any], offset: int, scale: int) -> bool:
+    """Static range proof: is ``(x - offset) * scale`` always in range for every ``x`` in dtype?
+
+    Uses Python ints (unbounded) to avoid overflow in the proof itself.
+    """
+    info = np.iinfo(dtype)
+    d_lo = int(info.min) - offset
+    d_hi = int(info.max) - offset
+    # Taking min/max of both products handles negative scale without a sign branch.
+    products = (d_lo * scale, d_hi * scale)
+    lo, hi = min(products), max(products)
+    return info.min <= lo and hi <= info.max
+
+
+def _decode_fits_natively(dtype: np.dtype[Any], offset: int, scale: int) -> bool:
+    """Static range proof for decode: is ``x // scale + offset`` always in range?"""
+    info = np.iinfo(dtype)
+    # x // scale is monotonic in x, so its extremes occur at the dtype's endpoints.
+    if scale > 0:
+        q_lo, q_hi = int(info.min) // scale, int(info.max) // scale
+    else:
+        q_lo, q_hi = int(info.max) // scale, int(info.min) // scale
+    lo, hi = q_lo + offset, q_hi + offset
+    return info.min <= lo and hi <= info.max
+
+
+def _check_int_range(
+    values: npt.NDArray[np.integer[Any]], target: np.dtype[np.integer[Any]]
+) -> None:
+    """Raise if any value is outside the representable range of ``target``.
+
+    Uses a single min/max pass instead of two ``np.any`` passes.
+    """
+    info = np.iinfo(target)
+    lo, hi = values.min(), values.max()
+    if lo < info.min or hi > info.max:
+        raise ValueError(
+            f"scale_offset produced a value outside the range of dtype {target} "
+            f"[{info.min}, {info.max}]."
+        )
+
+
+def _check_exact_division(
+    arr: npt.NDArray[np.integer[Any]], scale: np.integer[Any], scale_repr: object
+) -> None:
+    """Raise ValueError if ``arr`` has any element not exactly divisible by ``scale``."""
+    if np.any(arr % scale):
+        raise ValueError(
+            f"scale_offset decode produced a non-zero remainder when dividing by "
+            f"scale={scale_repr!r}; result is not exactly representable in dtype {arr.dtype}."
+        )
+
+
+def _encode_int_native(
+    arr: npt.NDArray[np.integer[Any]], offset: np.integer[Any], scale: np.integer[Any]
+) -> npt.NDArray[np.integer[Any]]:
+    """Compute ``(arr - offset) * scale`` directly in ``arr.dtype``.
+
+    This is the fast path; it exists only as a separate function to make the contract with
+    ``_encode_fits_natively`` explicit: the caller must have already proved that no ``x`` in
+    ``arr.dtype``'s range can overflow, so we can skip widening and range-checking entirely.
+    Using it without that proof would silently wrap on overflow.
+    """
+    return cast("npt.NDArray[np.integer[Any]]", (arr - offset) * scale)
+
+
+def _encode_int_widened(
+    arr: npt.NDArray[np.integer[Any]], offset: np.integer[Any], scale: np.integer[Any]
+) -> npt.NDArray[np.integer[Any]]:
+    """Overflow-checked integer encode for int8..int64 and uint8..uint32.
+
+    Exists because numpy integer arithmetic silently wraps on overflow, which the spec
+    forbids. We widen to int64, perform the arithmetic there, range-check against the
+    target dtype, then cast back. uint64 cannot use this path because its range exceeds
+    int64 — see ``_encode_uint64``.
+    """
+    wide_arr = arr.astype(_WIDE_INT, copy=False)
+    result = (wide_arr - _WIDE_INT.type(offset)) * _WIDE_INT.type(scale)
+    _check_int_range(result, arr.dtype)
+    return result.astype(arr.dtype, copy=False)
+
+
+def _encode_float(
+    arr: npt.NDArray[np.floating[Any]], offset: np.floating[Any], scale: np.floating[Any]
+) -> npt.NDArray[np.floating[Any]]:
+    """Encode float arrays in-dtype, guarding only against silent promotion.
+
+    Float arithmetic doesn't need widening — float64 is already the widest supported dtype,
+    and ``inf``/``nan`` from overflow are representable IEEE 754 values, so no range check is
+    required by the spec. The one thing that can still go wrong is numpy promoting the
+    result to a wider float dtype (e.g. float32 * float64 scalar -> float64), which would
+    violate the spec's "arithmetic semantics of the input array's data type" clause.
+    """
+    result = cast("npt.NDArray[np.floating[Any]]", (arr - offset) * scale)
+    if result.dtype != arr.dtype:
+        raise ValueError(
+            f"scale_offset changed dtype from {arr.dtype} to {result.dtype}. "
+            f"Arithmetic must preserve the data type."
+        )
+    return result
+
+
+def _check_py_int_range(
+    result: np.ndarray[tuple[Any, ...], np.dtype[Any]],
+    target: np.dtype[np.unsignedinteger[Any]],
+) -> None:
+    """Range-check an ``object``-dtype ndarray holding Python ints against ``target``'s iinfo.
+
+    Exists as a uint64-specific counterpart to ``_check_int_range``. That one compares numpy
+    integers against ``iinfo``; here the values are unbounded Python ints produced by
+    ``_encode_uint64`` / ``_decode_uint64``, so we rely on Python's arbitrary-precision
+    comparison to detect values outside the target dtype's range.
+    """
+    info = np.iinfo(target)
+    # np.min/np.max on an object array returns a Python int (which compares correctly with iinfo).
+    # Works uniformly for 0-d arrays where .flat iteration is awkward.
+    lo = np.min(result)
+    hi = np.max(result)
+    if lo < int(info.min) or hi > int(info.max):
+        raise ValueError(
+            f"scale_offset produced a value outside the range of dtype {target} "
+            f"[{info.min}, {info.max}]."
+        )
+
+
+def _encode_uint64(
+    arr: npt.NDArray[np.unsignedinteger[Any]], offset: int, scale: int
+) -> npt.NDArray[np.unsignedinteger[Any]]:
+    """Encode uint64 via Python-int arithmetic in an ``object``-dtype array.
+
+    Exists because uint64's range [0, 2**64) exceeds int64, so the int64 widening used by
+    ``_encode_int_widened`` would itself overflow. Python ints are unbounded, so computing
+    via ``object`` dtype is correct by construction. The trade-off is speed: object-dtype
+    arithmetic is interpreted per element and is roughly 10x slower than ufunc paths.
+    """
+    obj = arr.astype(object, copy=False)
+    # np.asarray restores ndarray-ness in the 0-d/scalar edge case.
+    result = np.asarray((obj - offset) * scale, dtype=object)
+    _check_py_int_range(result, arr.dtype)
+    return cast("npt.NDArray[np.unsignedinteger[Any]]", result.astype(arr.dtype, copy=False))
+
+
+def _decode_uint64(
+    arr: npt.NDArray[np.unsignedinteger[Any]], offset: int, scale: int
+) -> npt.NDArray[np.unsignedinteger[Any]]:
+    """Decode uint64 via Python-int arithmetic. See ``_encode_uint64`` for why."""
+    obj = arr.astype(object, copy=False)
+    result = np.asarray((obj // scale) + offset, dtype=object)
+    _check_py_int_range(result, arr.dtype)
+    return cast("npt.NDArray[np.unsignedinteger[Any]]", result.astype(arr.dtype, copy=False))
+
+
+def _decode_int_native(
+    arr: npt.NDArray[np.integer[Any]], offset: np.integer[Any], scale: np.integer[Any]
+) -> npt.NDArray[np.integer[Any]]:
+    """Compute ``arr // scale + offset`` directly in ``arr.dtype``.
+
+    Fast-path counterpart to ``_encode_int_native``; same contract. Caller must have proved
+    via ``_decode_fits_natively`` that the result can't overflow. Divisibility is checked
+    upstream in ``_decode`` before this is called, so ``//`` is exact here.
+    """
+    return cast("npt.NDArray[np.integer[Any]]", (arr // scale) + offset)
+
+
+def _decode_int_widened(
+    arr: npt.NDArray[np.integer[Any]], offset: np.integer[Any], scale: np.integer[Any]
+) -> npt.NDArray[np.integer[Any]]:
+    """Overflow-checked integer decode for int8..int64 and uint8..uint32.
+
+    Counterpart to ``_encode_int_widened``. Widens to int64 so the addition of ``offset``
+    after division can't silently wrap, then range-checks against the target dtype.
+    """
+    wide_arr = arr.astype(_WIDE_INT, copy=False)
+    result = (wide_arr // _WIDE_INT.type(scale)) + _WIDE_INT.type(offset)
+    _check_int_range(result, arr.dtype)
+    return result.astype(arr.dtype, copy=False)
+
+
+def _decode_float(
+    arr: npt.NDArray[np.floating[Any]], offset: np.floating[Any], scale: np.floating[Any]
+) -> npt.NDArray[np.floating[Any]]:
+    """Decode float arrays in-dtype, guarding only against silent promotion.
+
+    Counterpart to ``_encode_float``; same reasoning. ``arr / scale`` is true division and
+    always well-defined for floats (including ``0/0 = nan`` and ``x/0 = ±inf``), so no range
+    or exactness check is needed.
+    """
+    result = cast("npt.NDArray[np.floating[Any]]", (arr / scale) + offset)
+    if result.dtype != arr.dtype:
+        raise ValueError(
+            f"scale_offset changed dtype from {arr.dtype} to {result.dtype}. "
+            f"Arithmetic must preserve the data type."
+        )
+    return result
+
+
+def _encode(
+    arr: np.ndarray[tuple[Any, ...], np.dtype[Any]],
+    offset: np.generic,
+    scale: np.generic,
+) -> np.ndarray[tuple[Any, ...], np.dtype[Any]]:
+    """Compute ``(arr - offset) * scale`` without silent overflow, returning ``arr.dtype``."""
+    # uint64 is split out first because its full range (up to 2**64-1) doesn't fit in int64,
+    # so the widening strategy used for every other integer dtype would itself overflow.
+    if arr.dtype == np.uint64:
+        u_arr = cast("npt.NDArray[np.unsignedinteger[Any]]", arr)
+        return _encode_uint64(u_arr, int(offset), int(scale))
+    if np.issubdtype(arr.dtype, np.integer):
+        i_arr = cast("npt.NDArray[np.integer[Any]]", arr)
+        i_offset = cast("np.integer[Any]", offset)
+        i_scale = cast("np.integer[Any]", scale)
+        # Fast path: if a static proof shows no ``x`` in the dtype's range can overflow,
+        # skip the int64 widening and run the arithmetic directly in the input dtype.
+        if _encode_fits_natively(arr.dtype, int(offset), int(scale)):
+            return _encode_int_native(i_arr, i_offset, i_scale)
+        return _encode_int_widened(i_arr, i_offset, i_scale)
+    # Float path: arithmetic stays in-dtype (no widening); only guard against numpy
+    # silently promoting a narrower float to a wider one via scalar type mismatch.
+    f_arr = cast("npt.NDArray[np.floating[Any]]", arr)
+    f_offset = cast("np.floating[Any]", offset)
+    f_scale = cast("np.floating[Any]", scale)
+    return _encode_float(f_arr, f_offset, f_scale)
+
+
+def _decode(
+    arr: np.ndarray[tuple[Any, ...], np.dtype[Any]],
+    offset: np.generic,
+    scale: np.generic,
+    *,
+    scale_repr: object,
+) -> np.ndarray[tuple[Any, ...], np.dtype[Any]]:
+    """Compute ``arr / scale + offset`` without silent overflow, returning ``arr.dtype``."""
+    # uint64: same reasoning as _encode — its range exceeds int64, so the Python-int path is the
+    # only correct option. Exactness check runs first so non-divisible inputs fail before the
+    # slower object-dtype arithmetic.
+    if arr.dtype == np.uint64:
+        u_arr = cast("npt.NDArray[np.unsignedinteger[Any]]", arr)
+        _check_exact_division(u_arr, cast("np.integer[Any]", scale), scale_repr)
+        return _decode_uint64(u_arr, int(offset), int(scale))
+    if np.issubdtype(arr.dtype, np.integer):
+        i_arr = cast("npt.NDArray[np.integer[Any]]", arr)
+        i_offset = cast("np.integer[Any]", offset)
+        i_scale = cast("np.integer[Any]", scale)
+        # The spec requires decode to use true division and error if the result isn't
+        # representable. For integers that means the remainder must be zero; if any element
+        # isn't exactly divisible we fail here rather than silently truncating via //.
+        _check_exact_division(i_arr, i_scale, scale_repr)
+        # Fast path mirrors _encode: static proof that ``x // scale + offset`` stays in dtype.
+        if _decode_fits_natively(arr.dtype, int(offset), int(scale)):
+            return _decode_int_native(i_arr, i_offset, i_scale)
+        return _decode_int_widened(i_arr, i_offset, i_scale)
+    # Float path: division is well-defined; only guard against dtype promotion.
+    f_arr = cast("npt.NDArray[np.floating[Any]]", arr)
+    f_offset = cast("np.floating[Any]", offset)
+    f_scale = cast("np.floating[Any]", scale)
+    return _decode_float(f_arr, f_offset, f_scale)
+
+
 @dataclass(frozen=True)
 class ScaleOffset(ArrayArrayCodec):
     """Scale-offset array-to-array codec.
 
-    Encodes values by subtracting an offset and multiplying by a scale factor.
-    Decodes by dividing by the scale and adding the offset.
+    Encodes values with ``out = (in - offset) * scale`` and decodes with
+    ``out = (in / scale) + offset``, using the input array's data type semantics.
+    Intermediate or final values that are not representable in that dtype are reported
+    as errors (integer overflow, unsigned underflow, non-exact integer division).
 
-    All arithmetic uses the input array's data type semantics (no implicit promotion).
+    See https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/scale_offset
+    for the codec specification.
 
     Parameters
     ----------
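
The fast paths above rest on the static range proof, which needs only Python's unbounded integers. A minimal standalone sketch of the same check (the helper name and the example dtypes here are illustrative, not imports from the codec):

import numpy as np

def encode_fits_natively(dtype, offset, scale):
    # Same proof as the diff's helper, restated standalone: evaluate the
    # extremes of (x - offset) * scale with unbounded Python ints.
    info = np.iinfo(dtype)
    d_lo, d_hi = int(info.min) - offset, int(info.max) - offset
    products = (d_lo * scale, d_hi * scale)
    return info.min <= min(products) and max(products) <= info.max

print(encode_fits_natively(np.int8, 0, 1))  # True: (x - 0) * 1 is always x
print(encode_fits_natively(np.int8, 0, 2))  # False: x = 100 would give 200 > 127
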
@@ -78,6 +343,8 @@ def validate(
                 f"scale_offset codec only supports integer and floating-point data types. "
                 f"Got {dtype}."
             )
+        if self.scale == 0:
+            raise ValueError("scale_offset scale must be non-zero.")
         for name, value in [("offset", self.offset), ("scale", self.scale)]:
             try:
                 dtype.from_json_scalar(value, zarr_format=3)
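
Why the widened path exists at all: numpy integer arrays wrap on overflow instead of raising, so an unchecked in-dtype encode can return plausible-looking garbage. A small standalone illustration of the wrap and of the int64-widening check that catches it (values chosen for the example, assuming numpy's default wrapping behavior for array arithmetic):

import numpy as np

arr = np.array([100], dtype=np.int8)

# Native int8 arithmetic wraps: (100 - 0) * 2 = 200, which int8 stores as -56.
print((arr - np.int8(0)) * np.int8(2))  # [-56], no error raised

# The widened path makes the same overflow detectable: compute in int64,
# then compare the true result against int8's range before casting back.
wide = (arr.astype(np.int64) - 0) * 2
info = np.iinfo(np.int8)
if wide.min() < info.min or wide.max() > info.max:
    print(f"overflow detected: {wide} is outside [{info.min}, {info.max}]")
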
@@ -86,36 +353,25 @@ def validate(
                     f"scale_offset {name} value {value!r} is not representable in dtype {native}."
                 ) from e
 
-    def _to_scalar(self, value: float | str, dtype: ZDType[TBaseDType, TBaseScalar]) -> TBaseScalar:
-        """Convert a JSON-form value to a numpy scalar using the given dtype."""
-        return dtype.from_json_scalar(value, zarr_format=3)
-
     def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
         zdtype = chunk_spec.dtype
-        fill = chunk_spec.fill_value
-        offset = self._to_scalar(self.offset, zdtype)
-        scale = self._to_scalar(self.scale, zdtype)
-        new_fill = (zdtype.to_native_dtype().type(fill) - offset) * scale  # type: ignore[operator]
-        return replace(chunk_spec, fill_value=new_fill)
+        fill = np.asarray(zdtype.cast_scalar(chunk_spec.fill_value))
+        offset = cast("np.generic", zdtype.from_json_scalar(self.offset, zarr_format=3))
+        scale = cast("np.generic", zdtype.from_json_scalar(self.scale, zarr_format=3))
+        new_fill = _encode(fill, offset, scale)
+        return replace(chunk_spec, fill_value=new_fill.reshape(()).item())
 
     def _decode_sync(
         self,
         chunk_array: NDBuffer,
         chunk_spec: ArraySpec,
     ) -> NDBuffer:
-        arr = chunk_array.as_ndarray_like()
-        offset = self._to_scalar(self.offset, chunk_spec.dtype)
-        scale = self._to_scalar(self.scale, chunk_spec.dtype)
-        if np.issubdtype(arr.dtype, np.integer):
-            result = (arr // scale) + offset  # type: ignore[operator]
-        else:
-            result = (arr / scale) + offset  # type: ignore[operator]
-        if result.dtype != arr.dtype:
-            raise ValueError(
-                f"scale_offset decode changed dtype from {arr.dtype} to {result.dtype}. "
-                f"Arithmetic must preserve the data type."
-            )
-        return chunk_array.__class__.from_ndarray_like(result)
+        arr = cast("np.ndarray[tuple[Any, ...], np.dtype[Any]]", chunk_array.as_ndarray_like())
+        zdtype = chunk_spec.dtype
+        offset = cast("np.generic", zdtype.from_json_scalar(self.offset, zarr_format=3))
+        scale = cast("np.generic", zdtype.from_json_scalar(self.scale, zarr_format=3))
+        result = _decode(arr, offset, scale, scale_repr=self.scale)
+        return chunk_spec.prototype.nd_buffer.from_ndarray_like(result)
 
     async def _decode_single(
         self,
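
Two decode-side behaviors from this hunk, restated as a standalone sketch: the exact-division check that rejects truncating divisions, and the object-dtype trick that keeps uint64 arithmetic exact beyond int64's range (the array values here are illustrative):

import numpy as np

# Decode refuses to truncate: 7 is not exactly divisible by scale=2, so the
# codec raises instead of quietly returning 7 // 2 == 3.
arr = np.array([6, 7], dtype=np.int32)
print(np.any(arr % 2))  # True -> _check_exact_division would raise here

# uint64 values can exceed int64, so the codec does its arithmetic on Python
# ints via an object-dtype array, which cannot wrap.
big = np.array([2**64 - 2], dtype=np.uint64)
obj = big.astype(object)      # each element becomes an unbounded Python int
print((obj // 2) + 1)         # [9223372036854775808], exact
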
@@ -129,16 +385,12 @@ def _encode_sync(
         chunk_array: NDBuffer,
         chunk_spec: ArraySpec,
     ) -> NDBuffer | None:
-        arr = chunk_array.as_ndarray_like()
-        offset = self._to_scalar(self.offset, chunk_spec.dtype)
-        scale = self._to_scalar(self.scale, chunk_spec.dtype)
-        result = (arr - offset) * scale  # type: ignore[operator]
-        if result.dtype != arr.dtype:
-            raise ValueError(
-                f"scale_offset encode changed dtype from {arr.dtype} to {result.dtype}. "
-                f"Arithmetic must preserve the data type."
-            )
-        return chunk_array.__class__.from_ndarray_like(result)
+        arr = cast("np.ndarray[tuple[Any, ...], np.dtype[Any]]", chunk_array.as_ndarray_like())
+        zdtype = chunk_spec.dtype
+        offset = cast("np.generic", zdtype.from_json_scalar(self.offset, zarr_format=3))
+        scale = cast("np.generic", zdtype.from_json_scalar(self.scale, zarr_format=3))
+        result = _encode(arr, offset, scale)
+        return chunk_spec.prototype.nd_buffer.from_ndarray_like(result)
 
     async def _encode_single(
         self,
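
The promotion guard shared by `_encode_float` and `_decode_float` can be reproduced in isolation. Whether the promotion below actually occurs depends on the numpy version's casting rules (NEP 50, adopted in numpy 2.0, made float64 scalars promote float32 arrays), so treat this as an illustrative sketch:

import numpy as np

arr = np.ones(3, dtype=np.float32)
scale = np.float64(2.0)

result = (arr - np.float64(0.0)) * scale
print(result.dtype)  # float64 under NEP 50 promotion (numpy >= 2.0)

# _encode_float turns that promotion into an error instead of silently
# returning a wider dtype than the chunk's declared data type.
if result.dtype != arr.dtype:
    print(f"scale_offset changed dtype from {arr.dtype} to {result.dtype}")
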
