Skip to content

Commit ee15c9e

Browse files
committed
fix: use explicit list of data type names
1 parent def02b6 commit ee15c9e

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

src/zarr/codecs/cast_value.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from collections.abc import Mapping
1414
from dataclasses import dataclass, replace
15-
from typing import TYPE_CHECKING, Literal, TypedDict, cast
15+
from typing import TYPE_CHECKING, Final, Literal, TypedDict, cast
1616

1717
import numpy as np
1818

@@ -53,6 +53,36 @@ class ScalarMap(TypedDict, total=False):
5353
decode: Mapping[str | float | int, str | float | int]
5454

5555

56+
PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = {
57+
"int2",
58+
"int4",
59+
"int8",
60+
"int16",
61+
"int32",
62+
"int64",
63+
"int64uint2",
64+
"uint4",
65+
"uint8",
66+
"uint16",
67+
"uint32",
68+
"uint64",
69+
"uint64float4_e2m1fn",
70+
"float6_e2m3fn",
71+
"float6_e3m2fn",
72+
"float8_e3m4",
73+
"float8_e4m3",
74+
"float8_e4m3b11fnuz",
75+
"float8_e4m3fnuz",
76+
"float8_e5m2",
77+
"float8_e5m2fnuz",
78+
"float8_e8m0fnu",
79+
"bfloat16",
80+
"float16",
81+
"float32",
82+
"float64",
83+
}
84+
85+
5686
def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap:
5787
"""
5888
Parse a scalar map into its normalized dict-of-dicts form.
@@ -151,6 +181,12 @@ def __init__(
151181
zdtype = get_data_type_from_json(data_type, zarr_format=3)
152182
else:
153183
zdtype = data_type
184+
if zdtype.to_json(zarr_format=3) not in PERMITTED_DATA_TYPE_NAMES:
185+
raise ValueError(
186+
f"Invalid target data type {data_type!r}. "
187+
f"cast_value codec only supports integer and floating-point data types. "
188+
f"Got {zdtype}."
189+
)
154190
object.__setattr__(self, "dtype", zdtype)
155191
object.__setattr__(self, "rounding", rounding)
156192
object.__setattr__(self, "out_of_range", out_of_range)
@@ -188,14 +224,13 @@ def validate(
188224
dtype: ZDType[TBaseDType, TBaseScalar],
189225
chunk_grid: ChunkGridMetadata,
190226
) -> None:
191-
source_native = dtype.to_native_dtype()
192-
target_native = self.dtype.to_native_dtype()
193-
for label, dt in [("source", source_native), ("target", target_native)]:
194-
if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating):
195-
raise ValueError(
196-
f"The cast_value codec only supports integer and floating-point data types. "
197-
f"Got {label} dtype {dt}."
198-
)
227+
target_name = dtype.to_json(zarr_format=3)
228+
if target_name not in PERMITTED_DATA_TYPE_NAMES:
229+
raise ValueError(
230+
f"The cast_value codec only supports integer and floating-point data types. "
231+
f"Got dtype {target_name}."
232+
)
233+
target_native = dtype.to_native_dtype()
199234
if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer):
200235
raise ValueError("out_of_range='wrap' is only valid for integer target types.")
201236

@@ -318,6 +353,12 @@ async def _decode_single(
318353
return self._decode_sync(chunk_data, chunk_spec)
319354

320355
def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
356+
dtype_name = chunk_spec.dtype.to_json(zarr_format=3)
357+
if dtype_name not in PERMITTED_DATA_TYPE_NAMES:
358+
raise ValueError(
359+
"cast_value codec only supports fixed-size integer and floating-point data types. "
360+
f"Got source dtype: {chunk_spec.dtype}."
361+
)
321362
source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize
322363
target_itemsize = self.dtype.to_native_dtype().itemsize
323364
if source_itemsize == 0 or target_itemsize == 0:

tests/test_codecs/test_cast_value.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77

8+
import zarr
89
from tests.test_codecs.conftest import Expect, ExpectErr
910
from zarr.codecs.cast_value import CastValue
1011

@@ -119,7 +120,7 @@ def test_serialization_roundtrip(codec: CastValue) -> None:
119120
exception_cls=ValueError,
120121
),
121122
ExpectErr(
122-
input={"dtype": "int32", "target": "float32", "out_of_range": "wrap"},
123+
input={"dtype": "float32", "target": "int32", "out_of_range": "wrap"},
123124
msg="only valid for integer",
124125
exception_cls=ValueError,
125126
),
@@ -128,8 +129,6 @@ def test_serialization_roundtrip(codec: CastValue) -> None:
128129
)
129130
def test_validation_rejects_invalid(case: ExpectErr[dict[str, Any]]) -> None:
130131
"""Invalid dtype or out_of_range combinations are rejected at array creation."""
131-
import zarr
132-
133132
with pytest.raises(case.exception_cls, match=case.msg):
134133
zarr.create_array(
135134
store={},
@@ -161,7 +160,7 @@ def test_zero_itemsize_raises() -> None:
161160
config=ArrayConfig(order="C", write_empty_chunks=True),
162161
prototype=default_buffer_prototype(),
163162
)
164-
with pytest.raises(ValueError, match="fixed-size data types"):
163+
with pytest.raises(ValueError, match="fixed-size integer and floating-point data types"):
165164
codec.compute_encoded_size(100, spec)
166165

167166

0 commit comments

Comments
 (0)