|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | import pytest |
4 | 5 |
|
5 | 6 | import zarr |
| 7 | +from zarr.codecs import BytesCodec, CastValue |
6 | 8 | from zarr.core.array import _get_chunk_spec |
7 | 9 | from zarr.core.buffer.core import default_buffer_prototype |
8 | 10 | from zarr.core.indexing import BasicIndexer |
@@ -70,3 +72,51 @@ async def test_read_returns_get_results( |
70 | 72 | assert len(results) == len(expected_statuses) |
71 | 73 | for result, expected_status in zip(results, expected_statuses, strict=True): |
72 | 74 | assert result["status"] == expected_status |
| 75 | + |
| 76 | + |
# Probe for the optional ``cast_value_rs`` package. Only its importability
# matters here, so the imported name is intentionally left unused.
try:
    import cast_value_rs  # noqa: F401
except ModuleNotFoundError:
    _HAS_CAST_VALUE_RS = False
else:
    _HAS_CAST_VALUE_RS = True
| 83 | + |
# Marker that skips the decorated test whenever ``_HAS_CAST_VALUE_RS`` is false,
# i.e. the optional cast-value-rs package could not be imported.
requires_cast_value_rs = pytest.mark.skipif(
    not _HAS_CAST_VALUE_RS,
    reason="cast-value-rs not installed",
)
| 87 | + |
| 88 | + |
@requires_cast_value_rs
@pytest.mark.parametrize(
    ("source_dtype", "target_dtype"),
    [
        # Single-byte source (no endianness) cast to a multi-byte target (has
        # endianness). Without the fix, BytesCodec.evolve_from_array_spec sees
        # the source dtype, strips its `endian` to None, and then chokes when
        # the chunk_spec dtype gets transformed to the multi-byte target before
        # bytes-decoding.
        ("int8", "int16"),
        ("uint8", "int32"),
        ("int8", "float32"),
        # Multi-byte source cast to a single-byte target: the reverse direction
        # also exercises the spec-threading logic.
        ("int16", "int8"),
    ],
)
def test_codec_pipeline_threads_dtype_through_evolve(source_dtype: str, target_dtype: str) -> None:
    """Round-trip data through a ``CastValue`` filter and a ``BytesCodec`` serializer.

    Regression for #3937: each codec must be evolved against the spec it will
    see at runtime, not the original array spec. cast_value transforms the
    dtype between the array->array codecs and the array->bytes serializer.

    Parameters
    ----------
    source_dtype
        dtype the array is created with and the data is written/read as.
    target_dtype
        dtype handed to the ``CastValue`` filter.
    """
    values = [0, 1, 2, 3]
    cast_filter = CastValue(data_type=target_dtype)
    bytes_serializer = BytesCodec(endian="little")

    # One 4-element chunk in a fresh in-memory (dict) store, no compression.
    arr = zarr.create_array(
        store={},
        shape=(4,),
        chunks=(4,),
        dtype=source_dtype,
        fill_value=0,
        filters=[cast_filter],
        serializer=bytes_serializer,
        compressors=[],
        zarr_format=3,
        overwrite=True,
    )

    # Write, then read back: the data must survive the codec pipeline unchanged.
    arr[:] = np.asarray(values, dtype=source_dtype)
    expected = np.asarray(values, dtype=source_dtype)
    np.testing.assert_array_equal(arr[:], expected)
0 commit comments