11from __future__ import annotations
22
33from dataclasses import dataclass , replace
4- from typing import TYPE_CHECKING
4+ from typing import TYPE_CHECKING , Any , cast
55
66import numpy as np
7+ import numpy .typing as npt
78
89from zarr .abc .codec import ArrayArrayCodec
910from zarr .core .common import JSON , parse_named_configuration
1718 from zarr .core .metadata .v3 import ChunkGridMetadata
1819
1920
# Widening target for overflow-checked integer arithmetic (the widest numpy
# signed integer dtype available).
_WIDE_INT = np.dtype(np.int64)
22+
23+
24+ def _encode_fits_natively (dtype : np .dtype [Any ], offset : int , scale : int ) -> bool :
25+ """Static range proof: is ``(x - offset) * scale`` always in range for every ``x`` in dtype?
26+
27+ Uses Python ints (unbounded) to avoid overflow in the proof itself.
28+ """
29+ info = np .iinfo (dtype )
30+ d_lo = int (info .min ) - offset
31+ d_hi = int (info .max ) - offset
32+ # Taking min/max of both products handles negative scale without a sign branch.
33+ products = (d_lo * scale , d_hi * scale )
34+ lo , hi = min (products ), max (products )
35+ return info .min <= lo and hi <= info .max
36+
37+
38+ def _decode_fits_natively (dtype : np .dtype [Any ], offset : int , scale : int ) -> bool :
39+ """Static range proof for decode: is ``x // scale + offset`` always in range?"""
40+ info = np .iinfo (dtype )
41+ # x // scale is bounded by the extremes of x / scale (integer division stays within that range)
42+ if scale > 0 :
43+ q_lo , q_hi = int (info .min ) // scale , int (info .max ) // scale
44+ else :
45+ q_lo , q_hi = int (info .max ) // scale , int (info .min ) // scale
46+ lo , hi = q_lo + offset , q_hi + offset
47+ return info .min <= lo and hi <= info .max
48+
49+
50+ def _check_int_range (
51+ values : npt .NDArray [np .integer [Any ]], target : np .dtype [np .integer [Any ]]
52+ ) -> None :
53+ """Raise if any value is outside the representable range of ``target``.
54+
55+ Uses a single min/max pass instead of two ``np.any`` passes.
56+ """
57+ info = np .iinfo (target )
58+ lo , hi = values .min (), values .max ()
59+ if lo < info .min or hi > info .max :
60+ raise ValueError (
61+ f"scale_offset produced a value outside the range of dtype { target } "
62+ f"[{ info .min } , { info .max } ]."
63+ )
64+
65+
66+ def _check_exact_division (
67+ arr : npt .NDArray [np .integer [Any ]], scale : np .integer [Any ], scale_repr : object
68+ ) -> None :
69+ """Raise ValueError if ``arr`` has any element not exactly divisible by ``scale``."""
70+ if np .any (arr % scale ):
71+ raise ValueError (
72+ f"scale_offset decode produced a non-zero remainder when dividing by "
73+ f"scale={ scale_repr !r} ; result is not exactly representable in dtype { arr .dtype } ."
74+ )
75+
76+
77+ def _encode_int_native (
78+ arr : npt .NDArray [np .integer [Any ]], offset : np .integer [Any ], scale : np .integer [Any ]
79+ ) -> npt .NDArray [np .integer [Any ]]:
80+ """Compute ``(arr - offset) * scale`` directly in ``arr.dtype``.
81+
82+ This is the fast path; it exists only as a separate function to make the contract with
83+ ``_encode_fits_natively`` explicit: the caller must have already proved that no ``x`` in
84+ ``arr.dtype``'s range can overflow, so we can skip widening and range-checking entirely.
85+ Using it without that proof would silently wrap on overflow.
86+ """
87+ return cast ("npt.NDArray[np.integer[Any]]" , (arr - offset ) * scale )
88+
89+
90+ def _encode_int_widened (
91+ arr : npt .NDArray [np .integer [Any ]], offset : np .integer [Any ], scale : np .integer [Any ]
92+ ) -> npt .NDArray [np .integer [Any ]]:
93+ """Overflow-checked integer encode for int8..int64 and uint8..uint32.
94+
95+ Exists because numpy integer arithmetic silently wraps on overflow, which the spec
96+ forbids. We widen to int64, perform the arithmetic there (int64 holds the product of any
97+ two values from these dtypes), range-check against the target dtype, then cast back.
98+ uint64 cannot use this path because its range exceeds int64 — see ``_encode_uint64``.
99+ """
100+ wide_arr = arr .astype (_WIDE_INT , copy = False )
101+ result = (wide_arr - _WIDE_INT .type (offset )) * _WIDE_INT .type (scale )
102+ _check_int_range (result , arr .dtype )
103+ return result .astype (arr .dtype , copy = False )
104+
105+
106+ def _encode_float (
107+ arr : npt .NDArray [np .floating [Any ]], offset : np .floating [Any ], scale : np .floating [Any ]
108+ ) -> npt .NDArray [np .floating [Any ]]:
109+ """Encode float arrays in-dtype, guarding only against silent promotion.
110+
111+ Float arithmetic doesn't need widening — float64 is already the widest supported dtype,
112+ and ``inf``/``nan`` from overflow are representable IEEE 754 values, so no range check is
113+ required by the spec. The one thing that can still go wrong is numpy promoting the
114+ result to a wider float dtype (e.g. float32 * float64 scalar -> float64), which would
115+ violate the spec's "arithmetic semantics of the input array's data type" clause.
116+ """
117+ result = cast ("npt.NDArray[np.floating[Any]]" , (arr - offset ) * scale )
118+ if result .dtype != arr .dtype :
119+ raise ValueError (
120+ f"scale_offset changed dtype from { arr .dtype } to { result .dtype } . "
121+ f"Arithmetic must preserve the data type."
122+ )
123+ return result
124+
125+
126+ def _check_py_int_range (
127+ result : np .ndarray [tuple [Any , ...], np .dtype [Any ]],
128+ target : np .dtype [np .unsignedinteger [Any ]],
129+ ) -> None :
130+ """Range-check an ``object``-dtype ndarray holding Python ints against ``target``'s iinfo.
131+
132+ Exists as a uint64-specific counterpart to ``_check_int_range``. That one compares numpy
133+ integers against ``iinfo``; here the values are unbounded Python ints produced by
134+ ``_encode_uint64`` / ``_decode_uint64``, so we rely on Python's arbitrary-precision
135+ comparison to detect values outside the target dtype's range.
136+ """
137+ info = np .iinfo (target )
138+ # np.min/np.max on an object array returns a Python int (which compares correctly with iinfo).
139+ # Works uniformly for 0-d arrays where .flat iteration is awkward.
140+ lo = np .min (result )
141+ hi = np .max (result )
142+ if lo < int (info .min ) or hi > int (info .max ):
143+ raise ValueError (
144+ f"scale_offset produced a value outside the range of dtype { target } "
145+ f"[{ info .min } , { info .max } ]."
146+ )
147+
148+
def _encode_uint64(
    arr: npt.NDArray[np.unsignedinteger[Any]], offset: int, scale: int
) -> npt.NDArray[np.unsignedinteger[Any]]:
    """Encode uint64 with exact Python-int arithmetic on an ``object``-dtype array.

    uint64 spans [0, 2**64), which int64 cannot hold, so the widening used by
    ``_encode_int_widened`` would itself overflow. Unbounded Python ints are correct
    by construction; the cost is per-element interpreted arithmetic, roughly 10x
    slower than the ufunc paths.
    """
    as_objects = arr.astype(object, copy=False)
    # np.asarray keeps the 0-d/scalar edge case an ndarray.
    encoded = np.asarray((as_objects - offset) * scale, dtype=object)
    _check_py_int_range(encoded, arr.dtype)
    return cast("npt.NDArray[np.unsignedinteger[Any]]", encoded.astype(arr.dtype, copy=False))
164+
165+
def _decode_uint64(
    arr: npt.NDArray[np.unsignedinteger[Any]], offset: int, scale: int
) -> npt.NDArray[np.unsignedinteger[Any]]:
    """Decode uint64 with exact Python-int arithmetic; rationale in ``_encode_uint64``."""
    as_objects = arr.astype(object, copy=False)
    decoded = np.asarray((as_objects // scale) + offset, dtype=object)
    _check_py_int_range(decoded, arr.dtype)
    return cast("npt.NDArray[np.unsignedinteger[Any]]", decoded.astype(arr.dtype, copy=False))
174+
175+
176+ def _decode_int_native (
177+ arr : npt .NDArray [np .integer [Any ]], offset : np .integer [Any ], scale : np .integer [Any ]
178+ ) -> npt .NDArray [np .integer [Any ]]:
179+ """Compute ``arr // scale + offset`` directly in ``arr.dtype``.
180+
181+ Fast-path counterpart to ``_encode_int_native``; same contract. Caller must have proved
182+ via ``_decode_fits_natively`` that the result can't overflow. Divisibility is checked
183+ upstream in ``_decode`` before this is called, so ``//`` is exact here.
184+ """
185+ return cast ("npt.NDArray[np.integer[Any]]" , (arr // scale ) + offset )
186+
187+
188+ def _decode_int_widened (
189+ arr : npt .NDArray [np .integer [Any ]], offset : np .integer [Any ], scale : np .integer [Any ]
190+ ) -> npt .NDArray [np .integer [Any ]]:
191+ """Overflow-checked integer decode for int8..int64 and uint8..uint32.
192+
193+ Counterpart to ``_encode_int_widened``. Widens to int64 so the addition of ``offset``
194+ after division can't silently wrap, then range-checks against the target dtype.
195+ """
196+ wide_arr = arr .astype (_WIDE_INT , copy = False )
197+ result = (wide_arr // _WIDE_INT .type (scale )) + _WIDE_INT .type (offset )
198+ _check_int_range (result , arr .dtype )
199+ return result .astype (arr .dtype , copy = False )
200+
201+
202+ def _decode_float (
203+ arr : npt .NDArray [np .floating [Any ]], offset : np .floating [Any ], scale : np .floating [Any ]
204+ ) -> npt .NDArray [np .floating [Any ]]:
205+ """Decode float arrays in-dtype, guarding only against silent promotion.
206+
207+ Counterpart to ``_encode_float``; same reasoning. ``arr / scale`` is true division and
208+ always well-defined for floats (including ``0/0 = nan`` and ``x/0 = ±inf``), so no range
209+ or exactness check is needed.
210+ """
211+ result = cast ("npt.NDArray[np.floating[Any]]" , (arr / scale ) + offset )
212+ if result .dtype != arr .dtype :
213+ raise ValueError (
214+ f"scale_offset changed dtype from { arr .dtype } to { result .dtype } . "
215+ f"Arithmetic must preserve the data type."
216+ )
217+ return result
218+
219+
def _encode(
    arr: np.ndarray[tuple[Any, ...], np.dtype[Any]],
    offset: np.generic,
    scale: np.generic,
) -> np.ndarray[tuple[Any, ...], np.dtype[Any]]:
    """Compute ``(arr - offset) * scale`` without silent overflow, returning ``arr.dtype``.

    Dispatch:

    - uint64 always takes the exact Python-int path: its range [0, 2**64) exceeds
      int64, so the widening used for other integer dtypes would itself overflow.
    - Other integers first try the native fast path (static proof that no value of the
      dtype can overflow in-dtype). Failing that, int64 also takes the exact
      Python-int path, because "widening" int64 to int64 is a no-op: a wrapped product
      can land back in range and evade the range check with a silently wrong value.
      Narrower dtypes use checked int64 widening.
    - Floats stay in-dtype; only silent promotion is guarded against.
    """
    if arr.dtype == np.uint64:
        u_arr = cast("npt.NDArray[np.unsignedinteger[Any]]", arr)
        return _encode_uint64(u_arr, int(offset), int(scale))
    if np.issubdtype(arr.dtype, np.integer):
        i_arr = cast("npt.NDArray[np.integer[Any]]", arr)
        i_offset = cast("np.integer[Any]", offset)
        i_scale = cast("np.integer[Any]", scale)
        # Fast path: if a static proof shows no ``x`` in the dtype's range can overflow,
        # skip widening and run the arithmetic directly in the input dtype.
        if _encode_fits_natively(arr.dtype, int(offset), int(scale)):
            return _encode_int_native(i_arr, i_offset, i_scale)
        if arr.dtype.itemsize >= 8:
            # int64: int64 widening cannot detect overflow; use exact Python ints.
            # (_encode_uint64's object-dtype arithmetic is dtype-agnostic.)
            return _encode_uint64(
                cast("npt.NDArray[np.unsignedinteger[Any]]", arr), int(offset), int(scale)
            )
        return _encode_int_widened(i_arr, i_offset, i_scale)
    # Float path: arithmetic stays in-dtype (no widening); only guard against numpy
    # silently promoting a narrower float to a wider one via scalar type mismatch.
    f_arr = cast("npt.NDArray[np.floating[Any]]", arr)
    f_offset = cast("np.floating[Any]", offset)
    f_scale = cast("np.floating[Any]", scale)
    return _encode_float(f_arr, f_offset, f_scale)
246+
247+
def _decode(
    arr: np.ndarray[tuple[Any, ...], np.dtype[Any]],
    offset: np.generic,
    scale: np.generic,
    *,
    scale_repr: object,
) -> np.ndarray[tuple[Any, ...], np.dtype[Any]]:
    """Compute ``arr / scale + offset`` without silent overflow, returning ``arr.dtype``.

    Dispatch mirrors ``_encode``:

    - uint64 always takes the exact Python-int path (its range exceeds int64). The
      exactness check runs first so non-divisible inputs fail before the slower
      object-dtype arithmetic.
    - Other integers are checked for exact divisibility (the spec requires true
      division, so a non-zero remainder is an error rather than silent truncation),
      then try the native fast path. Failing the proof, int64 also takes the exact
      Python-int path — int64 widening cannot detect its own overflow (e.g.
      ``int64_max // 1 + 1`` wraps back into range). Narrower dtypes use checked
      int64 widening.
    - Floats use well-defined true division; only silent promotion is guarded against.
    """
    if arr.dtype == np.uint64:
        u_arr = cast("npt.NDArray[np.unsignedinteger[Any]]", arr)
        _check_exact_division(u_arr, cast("np.integer[Any]", scale), scale_repr)
        return _decode_uint64(u_arr, int(offset), int(scale))
    if np.issubdtype(arr.dtype, np.integer):
        i_arr = cast("npt.NDArray[np.integer[Any]]", arr)
        i_offset = cast("np.integer[Any]", offset)
        i_scale = cast("np.integer[Any]", scale)
        # Fail on non-exact division before any arithmetic path runs.
        _check_exact_division(i_arr, i_scale, scale_repr)
        # Fast path: static proof that ``x // scale + offset`` stays in dtype.
        if _decode_fits_natively(arr.dtype, int(offset), int(scale)):
            return _decode_int_native(i_arr, i_offset, i_scale)
        if arr.dtype.itemsize >= 8:
            # int64: widening to int64 is a no-op and cannot catch overflow; use
            # exact Python ints (_decode_uint64's object arithmetic is dtype-agnostic).
            return _decode_uint64(
                cast("npt.NDArray[np.unsignedinteger[Any]]", arr), int(offset), int(scale)
            )
        return _decode_int_widened(i_arr, i_offset, i_scale)
    # Float path: division is well-defined; only guard against dtype promotion.
    f_arr = cast("npt.NDArray[np.floating[Any]]", arr)
    f_offset = cast("np.floating[Any]", offset)
    f_scale = cast("np.floating[Any]", scale)
    return _decode_float(f_arr, f_offset, f_scale)
280+
281+
20282@dataclass (frozen = True )
21283class ScaleOffset (ArrayArrayCodec ):
22284 """Scale-offset array-to-array codec.
23285
24- Encodes values by subtracting an offset and multiplying by a scale factor.
25- Decodes by dividing by the scale and adding the offset.
286+ Encodes values with ``out = (in - offset) * scale`` and decodes with
287+ ``out = (in / scale) + offset``, using the input array's data type semantics.
288+ Intermediate or final values that are not representable in that dtype are reported
289+ as errors (integer overflow, unsigned underflow, non-exact integer division).
26290
27- All arithmetic uses the input array's data type semantics (no implicit promotion).
291+ See https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/scale_offset
292+ for the codec specification.
28293
29294 Parameters
30295 ----------
@@ -78,6 +343,8 @@ def validate(
78343 f"scale_offset codec only supports integer and floating-point data types. "
79344 f"Got { dtype } ."
80345 )
346+ if self .scale == 0 :
347+ raise ValueError ("scale_offset scale must be non-zero." )
81348 for name , value in [("offset" , self .offset ), ("scale" , self .scale )]:
82349 try :
83350 dtype .from_json_scalar (value , zarr_format = 3 )
@@ -86,36 +353,25 @@ def validate(
86353 f"scale_offset { name } value { value !r} is not representable in dtype { native } ."
87354 ) from e
88355
89- def _to_scalar (self , value : float | str , dtype : ZDType [TBaseDType , TBaseScalar ]) -> TBaseScalar :
90- """Convert a JSON-form value to a numpy scalar using the given dtype."""
91- return dtype .from_json_scalar (value , zarr_format = 3 )
92-
93356 def resolve_metadata (self , chunk_spec : ArraySpec ) -> ArraySpec :
94357 zdtype = chunk_spec .dtype
95- fill = chunk_spec .fill_value
96- offset = self . _to_scalar (self .offset , zdtype )
97- scale = self . _to_scalar (self .scale , zdtype )
98- new_fill = ( zdtype . to_native_dtype (). type ( fill ) - offset ) * scale # type: ignore[operator]
99- return replace (chunk_spec , fill_value = new_fill )
358+ fill = np . asarray ( zdtype . cast_scalar ( chunk_spec .fill_value ))
359+ offset = cast ( "np.generic" , zdtype . from_json_scalar (self .offset , zarr_format = 3 ) )
360+ scale = cast ( "np.generic" , zdtype . from_json_scalar (self .scale , zarr_format = 3 ) )
361+ new_fill = _encode ( fill , offset , scale )
362+ return replace (chunk_spec , fill_value = new_fill . reshape (()). item () )
100363
101364 def _decode_sync (
102365 self ,
103366 chunk_array : NDBuffer ,
104367 chunk_spec : ArraySpec ,
105368 ) -> NDBuffer :
106- arr = chunk_array .as_ndarray_like ()
107- offset = self ._to_scalar (self .offset , chunk_spec .dtype )
108- scale = self ._to_scalar (self .scale , chunk_spec .dtype )
109- if np .issubdtype (arr .dtype , np .integer ):
110- result = (arr // scale ) + offset # type: ignore[operator]
111- else :
112- result = (arr / scale ) + offset # type: ignore[operator]
113- if result .dtype != arr .dtype :
114- raise ValueError (
115- f"scale_offset decode changed dtype from { arr .dtype } to { result .dtype } . "
116- f"Arithmetic must preserve the data type."
117- )
118- return chunk_array .__class__ .from_ndarray_like (result )
369+ arr = cast ("np.ndarray[tuple[Any, ...], np.dtype[Any]]" , chunk_array .as_ndarray_like ())
370+ zdtype = chunk_spec .dtype
371+ offset = cast ("np.generic" , zdtype .from_json_scalar (self .offset , zarr_format = 3 ))
372+ scale = cast ("np.generic" , zdtype .from_json_scalar (self .scale , zarr_format = 3 ))
373+ result = _decode (arr , offset , scale , scale_repr = self .scale )
374+ return chunk_spec .prototype .nd_buffer .from_ndarray_like (result )
119375
120376 async def _decode_single (
121377 self ,
@@ -129,16 +385,12 @@ def _encode_sync(
129385 chunk_array : NDBuffer ,
130386 chunk_spec : ArraySpec ,
131387 ) -> NDBuffer | None :
132- arr = chunk_array .as_ndarray_like ()
133- offset = self ._to_scalar (self .offset , chunk_spec .dtype )
134- scale = self ._to_scalar (self .scale , chunk_spec .dtype )
135- result = (arr - offset ) * scale # type: ignore[operator]
136- if result .dtype != arr .dtype :
137- raise ValueError (
138- f"scale_offset encode changed dtype from { arr .dtype } to { result .dtype } . "
139- f"Arithmetic must preserve the data type."
140- )
141- return chunk_array .__class__ .from_ndarray_like (result )
388+ arr = cast ("np.ndarray[tuple[Any, ...], np.dtype[Any]]" , chunk_array .as_ndarray_like ())
389+ zdtype = chunk_spec .dtype
390+ offset = cast ("np.generic" , zdtype .from_json_scalar (self .offset , zarr_format = 3 ))
391+ scale = cast ("np.generic" , zdtype .from_json_scalar (self .scale , zarr_format = 3 ))
392+ result = _encode (arr , offset , scale )
393+ return chunk_spec .prototype .nd_buffer .from_ndarray_like (result )
142394
143395 async def _encode_single (
144396 self ,
0 commit comments