diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1116940a4cc..66bd461157e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,6 +17,8 @@ New Features - Added ``inherit='all_coords'`` option to :py:meth:`DataTree.to_dataset` to inherit all parent coordinates, not just indexed ones (:issue:`10812`, :pull:`11230`). By `Alfonso Ladino `_. +- Added complex dtype support to FillValueCoder for the Zarr backend. (:pull:`11151`) + By `Max Jones `_. Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 87bbfdba933..bce2fc215c6 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -15,6 +15,7 @@ import hypothesis.extra.numpy as npst import numpy as np from hypothesis import given +from hypothesis import strategies as st import xarray as xr from xarray.coding.times import _parse_iso8601 @@ -48,6 +49,22 @@ def test_CFScaleOffset_coder_roundtrip(original) -> None: xr.testing.assert_identical(original, roundtripped) +@given( + real=st.floats(allow_nan=True, allow_infinity=True), + imag=st.floats(allow_nan=True, allow_infinity=True), + dtype=st.sampled_from([np.complex64, np.complex128]), +) +def test_FillValueCoder_complex_roundtrip(real, imag, dtype) -> None: + from xarray.backends.zarr import FillValueCoder + + value = dtype(complex(real, imag)) + encoded = FillValueCoder.encode(value, np.dtype(dtype)) + decoded = FillValueCoder.decode(encoded, np.dtype(dtype)) + np.testing.assert_equal( + np.array(decoded, dtype=dtype), np.array(value, dtype=dtype) + ) + + @given(dt=datetimes()) def test_iso8601_decode(dt): iso = dt.isoformat() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6681673025c..d9279dc2de9 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -122,40 +122,89 @@ class FillValueCoder: """ @classmethod - def encode(cls, value: int | float | str | bytes, dtype: np.dtype[Any]) -> Any: + def encode( + cls, value: int | float | complex | str | bytes, dtype: np.dtype[Any] + ) -> Any: if dtype.kind == "S": # byte string, this implies that 'value' must also be `bytes` dtype. - assert isinstance(value, bytes) + if not isinstance(value, bytes): + raise TypeError( + f"Failed to encode fill_value: expected bytes for dtype {dtype}, got {type(value).__name__}" + ) return base64.standard_b64encode(value).decode() elif dtype.kind == "b": # boolean return bool(value) elif dtype.kind in "iu": - # todo: do we want to check for decimals? + if not isinstance(value, int | float | np.integer | np.floating): + raise TypeError( + f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}" + ) return int(value) elif dtype.kind == "f": + if not isinstance(value, int | float | np.integer | np.floating): + raise TypeError( + f"Failed to encode fill_value: expected int or float for dtype {dtype}, got {type(value).__name__}" + ) return base64.standard_b64encode(struct.pack(" None: assert actual3 == expected3 +@requires_zarr +@pytest.mark.parametrize("dtype", [complex, np.complex64, np.complex128]) +def test_fill_value_coder_complex(dtype) -> None: + """Test that FillValueCoder round-trips complex fill values.""" + from xarray.backends.zarr import FillValueCoder + + for value in [dtype(1 + 2j), dtype(-3.5 + 4.5j), dtype(complex("nan+nanj"))]: + encoded = FillValueCoder.encode(value, np.dtype(dtype)) + decoded = FillValueCoder.decode(encoded, np.dtype(dtype)) + np.testing.assert_equal(np.array(decoded, dtype=dtype), np.array(value)) + + +@requires_zarr +@pytest.mark.parametrize( + "value,dtype", + [ + (np.float32(np.inf), np.float32), + (np.float32(-np.inf), np.float32), + (np.float64(np.inf), np.float64), + (np.float64(-np.inf), np.float64), + (np.float32(np.nan), np.float32), + (np.float64(np.nan), np.float64), + ], +) +def test_fill_value_coder_inf_nan(value, dtype) -> None: + """Test that FillValueCoder round-trips inf and nan fill values.""" + from xarray.backends.zarr import FillValueCoder + + encoded = FillValueCoder.encode(value, np.dtype(dtype)) + decoded = FillValueCoder.decode(encoded, np.dtype(dtype)) + np.testing.assert_equal( + np.array(decoded, dtype=dtype), np.array(value, dtype=dtype) + ) + + @requires_zarr def test_extract_zarr_variable_encoding() -> None: var = xr.Variable("x", [1, 2])