Skip to content

Commit ad374b5

Browse files
authored
cast_value data type validation was checking the source data type instead of the target data type in the (#3938)
"can we use an out of range mode" check. This means a float dtype source and an int dtype target would raise an error, which is incorrect. the fix ensures that we check the _target_ dtype.
1 parent f8c0c5d commit ad374b5

2 files changed

Lines changed: 56 additions & 11 deletions

File tree

src/zarr/codecs/cast_value.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,31 @@ class ScalarMap(TypedDict, total=False):
5454

5555

5656
# see https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/cast_value
57-
PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = {
57+
CAST_VALUE_INT_DTYPES: Final[set[str]] = {
58+
# signed
5859
"int2",
5960
"int4",
6061
"int8",
6162
"int16",
6263
"int32",
6364
"int64",
64-
"int64uint2",
65+
# unsigned
66+
"uint2",
6567
"uint4",
6668
"uint8",
6769
"uint16",
6870
"uint32",
6971
"uint64",
70-
"uint64float4_e2m1fn",
72+
}
73+
"""Integer dtype identifiers permitted as the source or target of `cast_value`.
74+
75+
Membership in this set drives the `out_of_range="wrap"` rule, which the
76+
spec restricts to integral targets that use two's-complement representation
77+
for modular arithmetic.
78+
"""
79+
80+
CAST_VALUE_FLOAT_DTYPES: Final[set[str]] = {
81+
"float4_e2m1fn",
7182
"float6_e2m3fn",
7283
"float6_e3m2fn",
7384
"float8_e3m4",
@@ -82,6 +93,10 @@ class ScalarMap(TypedDict, total=False):
8293
"float32",
8394
"float64",
8495
}
96+
"""Floating-point dtype identifiers permitted as the source or target of `cast_value`."""
97+
98+
PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = CAST_VALUE_INT_DTYPES | CAST_VALUE_FLOAT_DTYPES
99+
"""All dtype identifiers the `cast_value` codec is defined for."""
85100

86101

87102
def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap:
@@ -240,15 +255,22 @@ def validate(
240255
dtype: ZDType[TBaseDType, TBaseScalar],
241256
chunk_grid: ChunkGridMetadata,
242257
) -> None:
243-
target_name = dtype.to_json(zarr_format=3)
244-
if target_name not in PERMITTED_DATA_TYPE_NAMES:
258+
# `dtype` is the source (the array's dtype); `self.dtype` is the
259+
# cast target. The spec requires both to be permitted, and rules
260+
# like `out_of_range="wrap"` apply to the target.
261+
source_name = dtype.to_json(zarr_format=3)
262+
target_name = self.dtype.to_json(zarr_format=3)
263+
for role, name in (("source", source_name), ("target", target_name)):
264+
if name not in PERMITTED_DATA_TYPE_NAMES:
265+
raise ValueError(
266+
f"The cast_value codec only supports integer and floating-point data types. "
267+
f"Got {role} dtype {name}."
268+
)
269+
if self.out_of_range == "wrap" and target_name not in CAST_VALUE_INT_DTYPES:
245270
raise ValueError(
246-
f"The cast_value codec only supports integer and floating-point data types. "
247-
f"Got dtype {target_name}."
271+
f"out_of_range='wrap' is only valid for integer target types. "
272+
f"Got target dtype {target_name}."
248273
)
249-
target_native = dtype.to_native_dtype()
250-
if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer):
251-
raise ValueError("out_of_range='wrap' is only valid for integer target types.")
252274

253275
if self.scalar_map is not None:
254276
self._validate_scalar_map(dtype, self.dtype)

tests/test_codecs/test_cast_value.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_construction_rejects_invalid_target_dtype() -> None:
139139
exception_cls=ValueError,
140140
),
141141
ExpectErr(
142-
input={"dtype": "float32", "target": "int32", "out_of_range": "wrap"},
142+
input={"dtype": "int32", "target": "float64", "out_of_range": "wrap"},
143143
msg="only valid for integer",
144144
exception_cls=ValueError,
145145
),
@@ -165,6 +165,29 @@ def test_validation_rejects_invalid(case: ExpectErr[dict[str, Any]]) -> None:
165165
)
166166

167167

168+
@pytest.mark.parametrize(
169+
("source_dtype", "target_dtype"),
170+
[
171+
("float16", "int8"),
172+
("float32", "int32"),
173+
("float64", "int64"),
174+
("int32", "uint8"),
175+
],
176+
)
177+
def test_validation_accepts_wrap_with_integer_target(source_dtype: str, target_dtype: str) -> None:
178+
"""Regression for #3936: `out_of_range="wrap"` is permitted when the
179+
cast TARGET (not the source array dtype) is an integer type."""
180+
zarr.create_array(
181+
store={},
182+
shape=(1,),
183+
dtype=source_dtype,
184+
chunks=(1,),
185+
filters=[CastValue(data_type=target_dtype, out_of_range="wrap")],
186+
compressors=None,
187+
fill_value=0,
188+
)
189+
190+
168191
def test_zero_itemsize_raises() -> None:
169192
"""Variable-length dtypes (itemsize=0) are rejected by compute_encoded_size."""
170193
from zarr.core.array_spec import ArrayConfig, ArraySpec

0 commit comments

Comments
 (0)