|
12 | 12 |
|
13 | 13 | from collections.abc import Mapping |
14 | 14 | from dataclasses import dataclass, replace |
15 | | -from typing import TYPE_CHECKING, Literal, TypedDict, cast |
| 15 | +from typing import TYPE_CHECKING, Final, Literal, TypedDict, cast |
16 | 16 |
|
17 | 17 | import numpy as np |
18 | 18 |
|
@@ -53,6 +53,36 @@ class ScalarMap(TypedDict, total=False): |
53 | 53 | decode: Mapping[str | float | int, str | float | int] |
54 | 54 |
|
55 | 55 |
|
| 56 | +PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = { |
| 57 | + "int2", |
| 58 | + "int4", |
| 59 | + "int8", |
| 60 | + "int16", |
| 61 | + "int32", |
| 62 | + "int64", |
| 63 | + "uint2", |
| 64 | + "uint4", |
| 65 | + "uint8", |
| 66 | + "uint16", |
| 67 | + "uint32", |
| 68 | + "uint64", |
| 69 | + "float4_e2m1fn", |
| 70 | + "float6_e2m3fn", |
| 71 | + "float6_e3m2fn", |
| 72 | + "float8_e3m4", |
| 73 | + "float8_e4m3", |
| 74 | + "float8_e4m3b11fnuz", |
| 75 | + "float8_e4m3fnuz", |
| 76 | + "float8_e5m2", |
| 77 | + "float8_e5m2fnuz", |
| 78 | + "float8_e8m0fnu", |
| 79 | + "bfloat16", |
| 80 | + "float16", |
| 81 | + "float32", |
| 82 | + "float64", |
| 83 | +} |
| 84 | + |
| 85 | + |
56 | 86 | def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap: |
57 | 87 | """ |
58 | 88 | Parse a scalar map into its normalized dict-of-dicts form. |
@@ -151,6 +181,12 @@ def __init__( |
151 | 181 | zdtype = get_data_type_from_json(data_type, zarr_format=3) |
152 | 182 | else: |
153 | 183 | zdtype = data_type |
| 184 | + if zdtype.to_json(zarr_format=3) not in PERMITTED_DATA_TYPE_NAMES: |
| 185 | + raise ValueError( |
| 186 | + f"Invalid target data type {data_type!r}. " |
| 187 | + f"cast_value codec only supports integer and floating-point data types. " |
| 188 | + f"Got {zdtype}." |
| 189 | + ) |
154 | 190 | object.__setattr__(self, "dtype", zdtype) |
155 | 191 | object.__setattr__(self, "rounding", rounding) |
156 | 192 | object.__setattr__(self, "out_of_range", out_of_range) |
@@ -188,14 +224,13 @@ def validate( |
188 | 224 | dtype: ZDType[TBaseDType, TBaseScalar], |
189 | 225 | chunk_grid: ChunkGridMetadata, |
190 | 226 | ) -> None: |
191 | | - source_native = dtype.to_native_dtype() |
192 | | - target_native = self.dtype.to_native_dtype() |
193 | | - for label, dt in [("source", source_native), ("target", target_native)]: |
194 | | - if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating): |
195 | | - raise ValueError( |
196 | | - f"The cast_value codec only supports integer and floating-point data types. " |
197 | | - f"Got {label} dtype {dt}." |
198 | | - ) |
| 227 | + target_name = dtype.to_json(zarr_format=3) |
| 228 | + if target_name not in PERMITTED_DATA_TYPE_NAMES: |
| 229 | + raise ValueError( |
| 230 | + f"The cast_value codec only supports integer and floating-point data types. " |
| 231 | + f"Got dtype {target_name}." |
| 232 | + ) |
| 233 | + target_native = dtype.to_native_dtype() |
199 | 234 | if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer): |
200 | 235 | raise ValueError("out_of_range='wrap' is only valid for integer target types.") |
201 | 236 |
|
@@ -318,6 +353,12 @@ async def _decode_single( |
318 | 353 | return self._decode_sync(chunk_data, chunk_spec) |
319 | 354 |
|
320 | 355 | def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: |
| 356 | + dtype_name = chunk_spec.dtype.to_json(zarr_format=3) |
| 357 | + if dtype_name not in PERMITTED_DATA_TYPE_NAMES: |
| 358 | + raise ValueError( |
| 359 | + "cast_value codec only supports fixed-size integer and floating-point data types. " |
| 360 | + f"Got source dtype: {chunk_spec.dtype}." |
| 361 | + ) |
321 | 362 | source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize |
322 | 363 | target_itemsize = self.dtype.to_native_dtype().itemsize |
323 | 364 | if source_itemsize == 0 or target_itemsize == 0: |
|
0 commit comments