1010
1111from __future__ import annotations
1212
13+ from collections .abc import Mapping
1314from dataclasses import dataclass , replace
14- from typing import TYPE_CHECKING , Literal , cast
15+ from typing import TYPE_CHECKING , Literal , TypedDict , cast
1516
1617import numpy as np
1718
2021from zarr .core .dtype import get_data_type_from_json
2122
2223if TYPE_CHECKING :
23- from collections .abc import Iterable , Mapping
24- from typing import Any , NotRequired , Self , TypedDict
24+ from typing import NotRequired , Self
2525
2626 from zarr .core .array_spec import ArraySpec
2727 from zarr .core .buffer import NDBuffer
@@ -32,8 +32,6 @@ class ScalarMapJSON(TypedDict):
3232 encode : NotRequired [list [tuple [object , object ]]]
3333 decode : NotRequired [list [tuple [object , object ]]]
3434
35- # Pre-parsed scalar map entry: (source_scalar, target_scalar)
36- ScalarMapEntry = tuple [np .integer [Any ] | np .floating [Any ], np .integer [Any ] | np .floating [Any ]]
3735
3836RoundingMode = Literal [
3937 "nearest-even" ,
@@ -46,88 +44,47 @@ class ScalarMapJSON(TypedDict):
4644OutOfRangeMode = Literal ["clamp" , "wrap" ]
4745
4846
49- # ---------------------------------------------------------------------------
50- # Scalar-map parsing helpers
51- # ---------------------------------------------------------------------------
52-
47+ class ScalarMap ( TypedDict ):
48+ """
49+ The normalized, in-memory form of a scalar map.
50+ """
5351
54- def _extract_raw_map (data : ScalarMapJSON | None , direction : str ) -> dict [str , str ] | None :
55- """Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'."""
56- if data is None :
57- return None
58- raw : dict [str , str ] = {}
59- pairs : list [tuple [object , object ]] = data .get (direction , []) # type: ignore[assignment]
60- for src , tgt in pairs :
61- raw [str (src )] = str (tgt )
62- return raw or None
52+ encode : NotRequired [Mapping [str | float | int , str | float | int ]]
53+ decode : NotRequired [Mapping [str | float | int , str | float | int ]]
6354
6455
65- def _parse_map_entries (
66- mapping : Mapping [str , str ],
67- src_dtype : ZDType [TBaseDType , TBaseScalar ],
68- tgt_dtype : ZDType [TBaseDType , TBaseScalar ],
69- ) -> tuple [ScalarMapEntry , ...]:
70- """Pre-parse a scalar map dict into a tuple of (src, tgt) pairs.
56+ def parse_scalar_map (obj : ScalarMapJSON | ScalarMap ) -> ScalarMap :
57+ """
58+ Parse a scalar map into its normalized dict-of-dicts form.
7159
72- Each entry's source value is deserialized using ``src_dtype`` and its target
73- value using ``tgt_dtype``, preserving full precision for both data types.
60+ Accepts either the JSON form (lists of tuples) or an already-normalized form
61+ (dicts). For example, ``{"encode": [("NaN", 0)]}`` becomes
62+ ``{"encode": {"NaN": 0}}``.
7463 """
75- entries : list [ScalarMapEntry ] = [
76- (
77- src_dtype .from_json_scalar (src_str , zarr_format = 3 ), # type: ignore[misc]
78- tgt_dtype .from_json_scalar (tgt_str , zarr_format = 3 ),
79- )
80- for src_str , tgt_str in mapping .items ()
81- ]
82- return tuple (entries )
64+ result : ScalarMap = {}
65+ for direction in ("encode" , "decode" ):
66+ if direction in obj :
67+ entries = obj [direction ]
68+ if entries is not None :
69+ if isinstance (entries , Mapping ):
70+ result [direction ] = entries
71+ else :
72+ result [direction ] = dict (entries ) # type: ignore[arg-type]
73+ return result
8374
8475
8576# ---------------------------------------------------------------------------
86- # Backend: cast-value-rs (optional)
77+ # Backend: cast-value-rs
8778# ---------------------------------------------------------------------------
8879
8980try :
90- from cast_value_rs import cast_array as _rs_cast_array
81+ from cast_value_rs import cast_array as cast_array_rs
9182
9283 _HAS_RUST_BACKEND = True
9384except ModuleNotFoundError :
9485 _HAS_RUST_BACKEND = False
9586
9687
97- def _dtype_to_str (dtype : np .dtype ) -> str : # type: ignore[type-arg]
98- return dtype .name
99-
100-
101- def _convert_scalar_map (
102- entries : Iterable [ScalarMapEntry ] | None ,
103- ) -> list [tuple [int | float , int | float ]] | None :
104- if entries is None :
105- return None
106- result : list [tuple [int | float , int | float ]] = []
107- for src , tgt in entries :
108- src_py : int | float = int (src ) if isinstance (src , np .integer ) else float (src )
109- tgt_py : int | float = int (tgt ) if isinstance (tgt , np .integer ) else float (tgt )
110- result .append ((src_py , tgt_py ))
111- return result
112-
113-
114- def _cast_array_rs (
115- arr : np .ndarray , # type: ignore[type-arg]
116- * ,
117- target_dtype : np .dtype , # type: ignore[type-arg]
118- rounding : RoundingMode ,
119- out_of_range : OutOfRangeMode | None ,
120- scalar_map_entries : Iterable [ScalarMapEntry ] | None ,
121- ) -> np .ndarray : # type: ignore[type-arg]
122- return _rs_cast_array ( # type: ignore[no-any-return]
123- arr = arr ,
124- target_dtype = _dtype_to_str (target_dtype ),
125- rounding_mode = rounding ,
126- out_of_range_mode = out_of_range ,
127- scalar_map_entries = _convert_scalar_map (scalar_map_entries ),
128- )
129-
130-
13188# ---------------------------------------------------------------------------
13289# Codec
13390# ---------------------------------------------------------------------------
@@ -166,15 +123,15 @@ class CastValue(ArrayArrayCodec):
166123 dtype : ZDType [TBaseDType , TBaseScalar ]
167124 rounding : RoundingMode
168125 out_of_range : OutOfRangeMode | None
169- scalar_map : ScalarMapJSON | None
126+ scalar_map : ScalarMap | None
170127
171128 def __init__ (
172129 self ,
173130 * ,
174131 data_type : str | ZDType [TBaseDType , TBaseScalar ],
175132 rounding : RoundingMode = "nearest-even" ,
176133 out_of_range : OutOfRangeMode | None = None ,
177- scalar_map : ScalarMapJSON | None = None ,
134+ scalar_map : ScalarMapJSON | ScalarMap | None = None ,
178135 ) -> None :
179136 if isinstance (data_type , str ):
180137 zdtype = get_data_type_from_json (data_type , zarr_format = 3 )
@@ -183,7 +140,11 @@ def __init__(
183140 object .__setattr__ (self , "dtype" , zdtype )
184141 object .__setattr__ (self , "rounding" , rounding )
185142 object .__setattr__ (self , "out_of_range" , out_of_range )
186- object .__setattr__ (self , "scalar_map" , scalar_map )
143+ if scalar_map is not None :
144+ parsed = parse_scalar_map (scalar_map )
145+ else :
146+ parsed = None
147+ object .__setattr__ (self , "scalar_map" , parsed )
187148
188149 @classmethod
189150 def from_dict (cls , data : dict [str , JSON ]) -> Self :
@@ -199,7 +160,11 @@ def to_dict(self) -> dict[str, JSON]:
199160 if self .out_of_range is not None :
200161 config ["out_of_range" ] = self .out_of_range
201162 if self .scalar_map is not None :
202- config ["scalar_map" ] = cast ("JSON" , self .scalar_map )
163+ json_map : dict [str , list [tuple [object , object ]]] = {}
164+ for direction in ("encode" , "decode" ):
165+ if direction in self .scalar_map :
166+ json_map [direction ] = [(k , v ) for k , v in self .scalar_map [direction ].items ()]
167+ config ["scalar_map" ] = cast ("JSON" , json_map )
203168 return {"name" : "cast_value" , "configuration" : config }
204169
205170 def validate (
@@ -225,21 +190,32 @@ def _do_cast(
225190 arr : np .ndarray , # type: ignore[type-arg]
226191 * ,
227192 target_dtype : np .dtype , # type: ignore[type-arg]
228- scalar_map_entries : Iterable [ ScalarMapEntry ] | None ,
193+ scalar_map : Mapping [ str | float | int , str | float | int ] | None ,
229194 ) -> np .ndarray : # type: ignore[type-arg]
230195 if not _HAS_RUST_BACKEND :
231196 raise ImportError (
232197 "The cast_value codec requires the 'cast-value-rs' package. "
233198 "Install it with: pip install cast-value-rs"
234199 )
235- return _cast_array_rs (
200+ scalar_map_entries : dict [float , float ] | None = None
201+ if scalar_map is not None :
202+ scalar_map_entries = {float (k ): float (v ) for k , v in scalar_map .items ()}
203+ return cast_array_rs ( # type: ignore[no-any-return]
236204 arr ,
237205 target_dtype = target_dtype ,
238- rounding = self .rounding ,
239- out_of_range = self .out_of_range ,
206+ rounding_mode = self .rounding ,
207+ out_of_range_mode = self .out_of_range ,
240208 scalar_map_entries = scalar_map_entries ,
241209 )
242210
211+ def _get_scalar_map (
212+ self , direction : str
213+ ) -> Mapping [str | float | int , str | float | int ] | None :
214+ """Extract the encode or decode mapping from scalar_map, or None."""
215+ if self .scalar_map is None :
216+ return None
217+ return self .scalar_map .get (direction ) # type: ignore[return-value]
218+
243219 def resolve_metadata (self , chunk_spec : ArraySpec ) -> ArraySpec :
244220 """
245221 Update the fill value of the output spec by applying casting procedure.
@@ -251,13 +227,8 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
251227 fill = chunk_spec .fill_value
252228 fill_arr = np .array ([fill ], dtype = source_native )
253229
254- encode_raw = _extract_raw_map (self .scalar_map , "encode" )
255- encode_entries = (
256- _parse_map_entries (encode_raw , chunk_spec .dtype , self .dtype ) if encode_raw else None
257- )
258-
259230 new_fill_arr = self ._do_cast (
260- fill_arr , target_dtype = target_native , scalar_map_entries = encode_entries
231+ fill_arr , target_dtype = target_native , scalar_map = self . _get_scalar_map ( "encode" )
261232 )
262233 new_fill = target_native .type (new_fill_arr [0 ])
263234
@@ -271,13 +242,8 @@ def _encode_sync(
271242 arr = chunk_array .as_ndarray_like ()
272243 target_native = self .dtype .to_native_dtype ()
273244
274- encode_raw = _extract_raw_map (self .scalar_map , "encode" )
275- encode_entries = (
276- _parse_map_entries (encode_raw , _chunk_spec .dtype , self .dtype ) if encode_raw else None
277- )
278-
279245 result = self ._do_cast (
280- np .asarray (arr ), target_dtype = target_native , scalar_map_entries = encode_entries
246+ np .asarray (arr ), target_dtype = target_native , scalar_map = self . _get_scalar_map ( "encode" )
281247 )
282248 return chunk_array .__class__ .from_ndarray_like (result )
283249
@@ -296,13 +262,8 @@ def _decode_sync(
296262 arr = chunk_array .as_ndarray_like ()
297263 target_native = chunk_spec .dtype .to_native_dtype ()
298264
299- decode_raw = _extract_raw_map (self .scalar_map , "decode" )
300- decode_entries = (
301- _parse_map_entries (decode_raw , self .dtype , chunk_spec .dtype ) if decode_raw else None
302- )
303-
304265 result = self ._do_cast (
305- np .asarray (arr ), target_dtype = target_native , scalar_map_entries = decode_entries
266+ np .asarray (arr ), target_dtype = target_native , scalar_map = self . _get_scalar_map ( "decode" )
306267 )
307268 return chunk_array .__class__ .from_ndarray_like (result )
308269
0 commit comments