Skip to content

Commit 0e01a21

Browse files
committed
chore: preserve JSON encoding of scale and offset parameters
1 parent b35d5a3 commit 0e01a21

File tree

2 files changed

+76
-18
lines changed

2 files changed

+76
-18
lines changed

src/zarr/codecs/scale_offset.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ class ScaleOffset(ArrayArrayCodec):
3636

3737
is_fixed_size = True
3838

39-
offset: int | float
40-
scale: int | float
39+
offset: int | float | str
40+
scale: int | float | str
4141

4242
def __init__(self, *, offset: object = 0, scale: object = 1) -> None:
43-
if not isinstance(offset, int | float):
44-
raise TypeError(f"offset must be a number, got {type(offset).__name__}")
45-
if not isinstance(scale, int | float):
46-
raise TypeError(f"scale must be a number, got {type(scale).__name__}")
43+
if not isinstance(offset, int | float | str):
44+
raise TypeError(f"offset must be a number or string, got {type(offset).__name__}")
45+
if not isinstance(scale, int | float | str):
46+
raise TypeError(f"scale must be a number or string, got {type(scale).__name__}")
4747
object.__setattr__(self, "offset", offset)
4848
object.__setattr__(self, "scale", scale)
4949

@@ -78,25 +78,38 @@ def validate(
7878
f"scale_offset codec only supports integer and floating-point data types. "
7979
f"Got {dtype}."
8080
)
81+
for name, value in [("offset", self.offset), ("scale", self.scale)]:
82+
try:
83+
dtype.from_json_scalar(value, zarr_format=3)
84+
except (TypeError, ValueError) as e:
85+
raise ValueError(
86+
f"scale_offset {name} value {value!r} is not representable in dtype {native}."
87+
) from e
88+
89+
def _to_scalar(self, value: float | str, dtype: ZDType[TBaseDType, TBaseScalar]) -> TBaseScalar:
90+
"""Convert a JSON-form value to a numpy scalar using the given dtype."""
91+
return dtype.from_json_scalar(value, zarr_format=3)
8192

8293
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
83-
native_dtype = chunk_spec.dtype.to_native_dtype()
94+
zdtype = chunk_spec.dtype
8495
fill = chunk_spec.fill_value
85-
new_fill = (native_dtype.type(fill) - native_dtype.type(self.offset)) * native_dtype.type( # type: ignore[operator]
86-
self.scale
87-
)
96+
offset = self._to_scalar(self.offset, zdtype)
97+
scale = self._to_scalar(self.scale, zdtype)
98+
new_fill = (zdtype.to_native_dtype().type(fill) - offset) * scale # type: ignore[operator]
8899
return replace(chunk_spec, fill_value=new_fill)
89100

90101
def _decode_sync(
91102
self,
92103
chunk_array: NDBuffer,
93-
_chunk_spec: ArraySpec,
104+
chunk_spec: ArraySpec,
94105
) -> NDBuffer:
95106
arr = chunk_array.as_ndarray_like()
107+
offset = self._to_scalar(self.offset, chunk_spec.dtype)
108+
scale = self._to_scalar(self.scale, chunk_spec.dtype)
96109
if np.issubdtype(arr.dtype, np.integer):
97-
result = (arr // arr.dtype.type(self.scale)) + arr.dtype.type(self.offset)
110+
result = (arr // scale) + offset # type: ignore[operator]
98111
else:
99-
result = (arr / arr.dtype.type(self.scale)) + arr.dtype.type(self.offset)
112+
result = (arr / scale) + offset # type: ignore[operator]
100113
if result.dtype != arr.dtype:
101114
raise ValueError(
102115
f"scale_offset decode changed dtype from {arr.dtype} to {result.dtype}. "
@@ -114,10 +127,12 @@ async def _decode_single(
114127
def _encode_sync(
115128
self,
116129
chunk_array: NDBuffer,
117-
_chunk_spec: ArraySpec,
130+
chunk_spec: ArraySpec,
118131
) -> NDBuffer | None:
119132
arr = chunk_array.as_ndarray_like()
120-
result = (arr - arr.dtype.type(self.offset)) * arr.dtype.type(self.scale)
133+
offset = self._to_scalar(self.offset, chunk_spec.dtype)
134+
scale = self._to_scalar(self.scale, chunk_spec.dtype)
135+
result = (arr - offset) * scale # type: ignore[operator]
121136
if result.dtype != arr.dtype:
122137
raise ValueError(
123138
f"scale_offset encode changed dtype from {arr.dtype} to {result.dtype}. "

tests/test_codecs/test_scale_offset.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,16 @@ def test_serialization_roundtrip() -> None:
8989
@pytest.mark.parametrize(
9090
"case",
9191
[
92-
ExpectErr(input={"offset": "bad"}, msg="offset must be a number", exception_cls=TypeError),
93-
ExpectErr(input={"scale": [1, 2]}, msg="scale must be a number", exception_cls=TypeError),
92+
ExpectErr(
93+
input={"offset": [1, 2]},
94+
msg="offset must be a number or string",
95+
exception_cls=TypeError,
96+
),
97+
ExpectErr(
98+
input={"scale": [1, 2]}, msg="scale must be a number or string", exception_cls=TypeError
99+
),
94100
],
95-
ids=["string-offset", "list-scale"],
101+
ids=["list-offset", "list-scale"],
96102
)
97103
def test_construction_rejects_non_numeric(case: ExpectErr[dict[str, Any]]) -> None:
98104
"""Non-numeric offset or scale is rejected at construction time."""
@@ -201,6 +207,43 @@ def test_rejects_complex_dtype() -> None:
201207
)
202208

203209

210+
@pytest.mark.parametrize(
211+
"case",
212+
[
213+
ExpectErr(
214+
input={"dtype": "int32", "offset": 1.5, "scale": 1},
215+
msg="offset value 1.5 is not representable",
216+
exception_cls=ValueError,
217+
),
218+
ExpectErr(
219+
input={"dtype": "int32", "offset": 0, "scale": 0.5},
220+
msg="scale value 0.5 is not representable",
221+
exception_cls=ValueError,
222+
),
223+
ExpectErr(
224+
input={"dtype": "int16", "offset": "NaN", "scale": 1},
225+
msg="offset value 'NaN' is not representable",
226+
exception_cls=ValueError,
227+
),
228+
],
229+
ids=["float-offset-for-int", "float-scale-for-int", "nan-offset-for-int"],
230+
)
231+
def test_rejects_unrepresentable_scale_offset(case: ExpectErr[dict[str, Any]]) -> None:
232+
"""Scale/offset values that can't be represented in the array dtype are rejected."""
233+
import zarr
234+
235+
with pytest.raises(case.exception_cls, match=case.msg):
236+
zarr.create_array(
237+
store={},
238+
shape=(10,),
239+
dtype=case.input["dtype"],
240+
chunks=(10,),
241+
filters=[ScaleOffset(offset=case.input["offset"], scale=case.input["scale"])],
242+
compressors=None,
243+
fill_value=0,
244+
)
245+
246+
204247
def test_dtype_preservation() -> None:
205248
"""Integer scale/offset arithmetic preserves the array dtype via floor division."""
206249
import zarr

0 commit comments

Comments
 (0)