Skip to content

Commit 7f5f2b2

Browse files
committed
chore: simplify scalar map handling
1 parent 3ed9847 commit 7f5f2b2

2 files changed

Lines changed: 81 additions & 96 deletions

File tree

src/zarr/codecs/cast_value.py

Lines changed: 57 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
from __future__ import annotations
1212

13+
from collections.abc import Mapping
1314
from dataclasses import dataclass, replace
14-
from typing import TYPE_CHECKING, Literal, cast
15+
from typing import TYPE_CHECKING, Literal, TypedDict, cast
1516

1617
import numpy as np
1718

@@ -20,8 +21,7 @@
2021
from zarr.core.dtype import get_data_type_from_json
2122

2223
if TYPE_CHECKING:
23-
from collections.abc import Iterable, Mapping
24-
from typing import Any, NotRequired, Self, TypedDict
24+
from typing import NotRequired, Self
2525

2626
from zarr.core.array_spec import ArraySpec
2727
from zarr.core.buffer import NDBuffer
@@ -32,8 +32,6 @@ class ScalarMapJSON(TypedDict):
3232
encode: NotRequired[list[tuple[object, object]]]
3333
decode: NotRequired[list[tuple[object, object]]]
3434

35-
# Pre-parsed scalar map entry: (source_scalar, target_scalar)
36-
ScalarMapEntry = tuple[np.integer[Any] | np.floating[Any], np.integer[Any] | np.floating[Any]]
3735

3836
RoundingMode = Literal[
3937
"nearest-even",
@@ -46,88 +44,47 @@ class ScalarMapJSON(TypedDict):
4644
OutOfRangeMode = Literal["clamp", "wrap"]
4745

4846

49-
# ---------------------------------------------------------------------------
50-
# Scalar-map parsing helpers
51-
# ---------------------------------------------------------------------------
52-
47+
class ScalarMap(TypedDict):
48+
"""
49+
The normalized, in-memory form of a scalar map: each direction maps source scalars to target scalars.
50+
"""
5351

54-
def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None:
55-
"""Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'."""
56-
if data is None:
57-
return None
58-
raw: dict[str, str] = {}
59-
pairs: list[tuple[object, object]] = data.get(direction, []) # type: ignore[assignment]
60-
for src, tgt in pairs:
61-
raw[str(src)] = str(tgt)
62-
return raw or None
52+
encode: NotRequired[Mapping[str | float | int, str | float | int]]
53+
decode: NotRequired[Mapping[str | float | int, str | float | int]]
6354

6455

65-
def _parse_map_entries(
66-
mapping: Mapping[str, str],
67-
src_dtype: ZDType[TBaseDType, TBaseScalar],
68-
tgt_dtype: ZDType[TBaseDType, TBaseScalar],
69-
) -> tuple[ScalarMapEntry, ...]:
70-
"""Pre-parse a scalar map dict into a tuple of (src, tgt) pairs.
56+
def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap:
57+
"""
58+
Parse a scalar map into its normalized dict-of-dicts form.
7159
72-
Each entry's source value is deserialized using ``src_dtype`` and its target
73-
value using ``tgt_dtype``, preserving full precision for both data types.
60+
Accepts either the JSON form (lists of ``(source, target)`` pairs) or an already-normalized form
61+
(dicts). For example, ``{"encode": [("NaN", 0)]}`` becomes
62+
``{"encode": {"NaN": 0}}``.
7463
"""
75-
entries: list[ScalarMapEntry] = [
76-
(
77-
src_dtype.from_json_scalar(src_str, zarr_format=3), # type: ignore[misc]
78-
tgt_dtype.from_json_scalar(tgt_str, zarr_format=3),
79-
)
80-
for src_str, tgt_str in mapping.items()
81-
]
82-
return tuple(entries)
64+
result: ScalarMap = {}
65+
for direction in ("encode", "decode"):
66+
if direction in obj:
67+
entries = obj[direction]
68+
if entries is not None:
69+
if isinstance(entries, Mapping):
70+
result[direction] = entries
71+
else:
72+
result[direction] = dict(entries) # type: ignore[arg-type]
73+
return result
8374

8475

8576
# ---------------------------------------------------------------------------
86-
# Backend: cast-value-rs (optional)
77+
# Backend: cast-value-rs
8778
# ---------------------------------------------------------------------------
8879

8980
try:
90-
from cast_value_rs import cast_array as _rs_cast_array
81+
from cast_value_rs import cast_array as cast_array_rs
9182

9283
_HAS_RUST_BACKEND = True
9384
except ModuleNotFoundError:
9485
_HAS_RUST_BACKEND = False
9586

9687

97-
def _dtype_to_str(dtype: np.dtype) -> str: # type: ignore[type-arg]
98-
return dtype.name
99-
100-
101-
def _convert_scalar_map(
102-
entries: Iterable[ScalarMapEntry] | None,
103-
) -> list[tuple[int | float, int | float]] | None:
104-
if entries is None:
105-
return None
106-
result: list[tuple[int | float, int | float]] = []
107-
for src, tgt in entries:
108-
src_py: int | float = int(src) if isinstance(src, np.integer) else float(src)
109-
tgt_py: int | float = int(tgt) if isinstance(tgt, np.integer) else float(tgt)
110-
result.append((src_py, tgt_py))
111-
return result
112-
113-
114-
def _cast_array_rs(
115-
arr: np.ndarray, # type: ignore[type-arg]
116-
*,
117-
target_dtype: np.dtype, # type: ignore[type-arg]
118-
rounding: RoundingMode,
119-
out_of_range: OutOfRangeMode | None,
120-
scalar_map_entries: Iterable[ScalarMapEntry] | None,
121-
) -> np.ndarray: # type: ignore[type-arg]
122-
return _rs_cast_array( # type: ignore[no-any-return]
123-
arr=arr,
124-
target_dtype=_dtype_to_str(target_dtype),
125-
rounding_mode=rounding,
126-
out_of_range_mode=out_of_range,
127-
scalar_map_entries=_convert_scalar_map(scalar_map_entries),
128-
)
129-
130-
13188
# ---------------------------------------------------------------------------
13289
# Codec
13390
# ---------------------------------------------------------------------------
@@ -166,15 +123,15 @@ class CastValue(ArrayArrayCodec):
166123
dtype: ZDType[TBaseDType, TBaseScalar]
167124
rounding: RoundingMode
168125
out_of_range: OutOfRangeMode | None
169-
scalar_map: ScalarMapJSON | None
126+
scalar_map: ScalarMap | None
170127

171128
def __init__(
172129
self,
173130
*,
174131
data_type: str | ZDType[TBaseDType, TBaseScalar],
175132
rounding: RoundingMode = "nearest-even",
176133
out_of_range: OutOfRangeMode | None = None,
177-
scalar_map: ScalarMapJSON | None = None,
134+
scalar_map: ScalarMapJSON | ScalarMap | None = None,
178135
) -> None:
179136
if isinstance(data_type, str):
180137
zdtype = get_data_type_from_json(data_type, zarr_format=3)
@@ -183,7 +140,11 @@ def __init__(
183140
object.__setattr__(self, "dtype", zdtype)
184141
object.__setattr__(self, "rounding", rounding)
185142
object.__setattr__(self, "out_of_range", out_of_range)
186-
object.__setattr__(self, "scalar_map", scalar_map)
143+
if scalar_map is not None:
144+
parsed = parse_scalar_map(scalar_map)
145+
else:
146+
parsed = None
147+
object.__setattr__(self, "scalar_map", parsed)
187148

188149
@classmethod
189150
def from_dict(cls, data: dict[str, JSON]) -> Self:
@@ -199,7 +160,11 @@ def to_dict(self) -> dict[str, JSON]:
199160
if self.out_of_range is not None:
200161
config["out_of_range"] = self.out_of_range
201162
if self.scalar_map is not None:
202-
config["scalar_map"] = cast("JSON", self.scalar_map)
163+
json_map: dict[str, list[tuple[object, object]]] = {}
164+
for direction in ("encode", "decode"):
165+
if direction in self.scalar_map:
166+
json_map[direction] = [(k, v) for k, v in self.scalar_map[direction].items()]
167+
config["scalar_map"] = cast("JSON", json_map)
203168
return {"name": "cast_value", "configuration": config}
204169

205170
def validate(
@@ -225,21 +190,32 @@ def _do_cast(
225190
arr: np.ndarray, # type: ignore[type-arg]
226191
*,
227192
target_dtype: np.dtype, # type: ignore[type-arg]
228-
scalar_map_entries: Iterable[ScalarMapEntry] | None,
193+
scalar_map: Mapping[str | float | int, str | float | int] | None,
229194
) -> np.ndarray: # type: ignore[type-arg]
230195
if not _HAS_RUST_BACKEND:
231196
raise ImportError(
232197
"The cast_value codec requires the 'cast-value-rs' package. "
233198
"Install it with: pip install cast-value-rs"
234199
)
235-
return _cast_array_rs(
200+
scalar_map_entries: dict[float, float] | None = None
201+
if scalar_map is not None:
202+
scalar_map_entries = {float(k): float(v) for k, v in scalar_map.items()}
203+
return cast_array_rs( # type: ignore[no-any-return]
236204
arr,
237205
target_dtype=target_dtype,
238-
rounding=self.rounding,
239-
out_of_range=self.out_of_range,
206+
rounding_mode=self.rounding,
207+
out_of_range_mode=self.out_of_range,
240208
scalar_map_entries=scalar_map_entries,
241209
)
242210

211+
def _get_scalar_map(
212+
self, direction: str
213+
) -> Mapping[str | float | int, str | float | int] | None:
214+
"""Extract the encode or decode mapping from scalar_map, or None."""
215+
if self.scalar_map is None:
216+
return None
217+
return self.scalar_map.get(direction) # type: ignore[return-value]
218+
243219
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
244220
"""
245221
Update the fill value of the output spec by applying casting procedure.
@@ -251,13 +227,8 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
251227
fill = chunk_spec.fill_value
252228
fill_arr = np.array([fill], dtype=source_native)
253229

254-
encode_raw = _extract_raw_map(self.scalar_map, "encode")
255-
encode_entries = (
256-
_parse_map_entries(encode_raw, chunk_spec.dtype, self.dtype) if encode_raw else None
257-
)
258-
259230
new_fill_arr = self._do_cast(
260-
fill_arr, target_dtype=target_native, scalar_map_entries=encode_entries
231+
fill_arr, target_dtype=target_native, scalar_map=self._get_scalar_map("encode")
261232
)
262233
new_fill = target_native.type(new_fill_arr[0])
263234

@@ -271,13 +242,8 @@ def _encode_sync(
271242
arr = chunk_array.as_ndarray_like()
272243
target_native = self.dtype.to_native_dtype()
273244

274-
encode_raw = _extract_raw_map(self.scalar_map, "encode")
275-
encode_entries = (
276-
_parse_map_entries(encode_raw, _chunk_spec.dtype, self.dtype) if encode_raw else None
277-
)
278-
279245
result = self._do_cast(
280-
np.asarray(arr), target_dtype=target_native, scalar_map_entries=encode_entries
246+
np.asarray(arr), target_dtype=target_native, scalar_map=self._get_scalar_map("encode")
281247
)
282248
return chunk_array.__class__.from_ndarray_like(result)
283249

@@ -296,13 +262,8 @@ def _decode_sync(
296262
arr = chunk_array.as_ndarray_like()
297263
target_native = chunk_spec.dtype.to_native_dtype()
298264

299-
decode_raw = _extract_raw_map(self.scalar_map, "decode")
300-
decode_entries = (
301-
_parse_map_entries(decode_raw, self.dtype, chunk_spec.dtype) if decode_raw else None
302-
)
303-
304265
result = self._do_cast(
305-
np.asarray(arr), target_dtype=target_native, scalar_map_entries=decode_entries
266+
np.asarray(arr), target_dtype=target_native, scalar_map=self._get_scalar_map("decode")
306267
)
307268
return chunk_array.__class__.from_ndarray_like(result)
308269

tests/test_codecs/test_cast_value.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,27 @@ def test_combined_with_scale_offset() -> None:
279279
arr[:] = data
280280
result = arr[:]
281281
np.testing.assert_array_almost_equal(result, data, decimal=1) # type: ignore[arg-type]
282+
283+
284+
@pytest.mark.parametrize(
285+
"case",
286+
[
287+
Expect(
288+
input={"encode": [("NaN", 0)]},
289+
expected={"encode": {"NaN": 0}},
290+
),
291+
Expect(
292+
input={"encode": [("NaN", 0)], "decode": [(0, "NaN")]},
293+
expected={"encode": {"NaN": 0}, "decode": {0: "NaN"}},
294+
),
295+
Expect(
296+
input={"encode": {"NaN": 0}},
297+
expected={"encode": {"NaN": 0}},
298+
),
299+
],
300+
ids=["encode-only", "both-directions", "already-normalized"],
301+
)
302+
def test_parse_scalar_map(case: Expect[Any, Any]) -> None:
303+
from zarr.codecs.cast_value import parse_scalar_map
304+
305+
assert parse_scalar_map(case.input) == case.expected

0 commit comments

Comments
 (0)