|
| 1 | +"""Unit tests for chunking module helper functions.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pytest |
| 9 | +import xarray as xr |
| 10 | +import zarr |
| 11 | + |
| 12 | +from mdio.api.io import open_mdio |
| 13 | +from mdio.api.io import to_mdio |
| 14 | +from mdio.builder.schemas.chunk_grid import RegularChunkGrid |
| 15 | +from mdio.builder.schemas.chunk_grid import RegularChunkShape |
| 16 | +from mdio.builder.schemas.compressors import Blosc |
| 17 | +from mdio.constants import ZarrFormat |
| 18 | +from mdio.transpose_writers.chunking import _normalize_chunk_grid |
| 19 | +from mdio.transpose_writers.chunking import _normalize_compressor |
| 20 | +from mdio.transpose_writers.chunking import _normalize_new_variable |
| 21 | +from mdio.transpose_writers.chunking import _remove_fillvalue_attrs |
| 22 | +from mdio.transpose_writers.chunking import _validate_inputs |
| 23 | + |
| 24 | +if TYPE_CHECKING: |
| 25 | + from pathlib import Path |
| 26 | + |
| 27 | + |
class TestRemoveFillvalueAttrs:
    """Round-trip checks for the ``_remove_fillvalue_attrs`` helper."""

    @pytest.mark.parametrize("zarr_format", [ZarrFormat.V2, ZarrFormat.V3])
    def test_remove_fillvalue_after_mdio_serialization(self, tmp_path: Path, zarr_format: ZarrFormat) -> None:
        """Check that no ``_FillValue`` attribute survives the helper for Zarr v2 and v3 stores."""
        # Float data with NaN entries; the test relies on serialization attaching _FillValue.
        values = np.array([[1.0, 2.0, np.nan], [3.0, np.nan, 4.0]], dtype=np.float32)
        data_vars = {"var1": (["x", "y"], values, {"units": "meters"})}
        coords = {"x": (["x"], [0, 1], {"axis": "X"})}
        dataset = xr.Dataset(data_vars, coords=coords)

        # Round-trip through an MDIO store written in the requested Zarr format.
        store_path = tmp_path / f"test_{zarr_format}.mdio"
        with zarr.config.set(default_zarr_format=zarr_format):
            to_mdio(dataset, store_path, mode="w")
            reopened = open_mdio(store_path)

        # The helper is invoked for its effect on ``reopened``; its return value is not used.
        _remove_fillvalue_attrs(reopened)

        # Neither data variables nor coordinates may still carry _FillValue.
        names = [*reopened.data_vars, *reopened.coords]
        for name in names:
            assert "_FillValue" not in reopened[name].attrs

        # Unrelated attributes must come through untouched.
        assert reopened["var1"].attrs["units"] == "meters"
        assert reopened["x"].attrs["axis"] == "X"
| 56 | + |
| 57 | + |
class TestValidateInputs:
    """Acceptance and rejection checks for the ``_validate_inputs`` helper."""

    @pytest.mark.parametrize(
        ("new_variable", "chunk_grid", "compressor", "should_pass"),
        [
            ("var1", "grid", "comp", True),
            (["var1", "var2"], "grid", "comp", True),
            ("var1", "grid", None, True),
            (123, "grid", "comp", False),
            ([], "grid", "comp", False),
            ("var1", [], "comp", False),
            ("var1", "grid", [], False),
        ],
    )
    def test_validation(
        self,
        new_variable: str | list | int,
        chunk_grid: str | list,
        compressor: str | list | None,
        should_pass: bool,
    ) -> None:
        """Run ``_validate_inputs`` on one parameter combination.

        The strings ``"grid"`` and ``"comp"`` in the parameter table are
        placeholders that are swapped for real schema objects before the call.
        Cases flagged ``should_pass=False`` must raise ``TypeError`` or
        ``ValueError``.
        """
        real_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(10, 10)))
        real_compressor = Blosc(cname="zstd", clevel=5, shuffle="shuffle")

        # Substitute the placeholders; anything else is forwarded as given.
        grid_arg = real_grid if chunk_grid == "grid" else chunk_grid
        compressor_arg = real_compressor if compressor == "comp" else compressor

        if not should_pass:
            with pytest.raises((TypeError, ValueError)):
                _validate_inputs(new_variable, grid_arg, compressor_arg)  # type: ignore[arg-type]
            return

        _validate_inputs(new_variable, grid_arg, compressor_arg)  # type: ignore[arg-type]
| 95 | + |
| 96 | + |
class TestNormalizeNewVariable:
    """Checks for the ``_normalize_new_variable`` helper."""

    @pytest.mark.parametrize(
        ("input_value", "expected"),
        [
            ("var1", ["var1"]),
            (["var1", "var2"], ["var1", "var2"]),
        ],
    )
    def test_normalize(self, input_value: str | list[str], expected: list[str]) -> None:
        """A lone name becomes a one-element list; a list of names compares equal to itself."""
        assert _normalize_new_variable(input_value) == expected
| 111 | + |
| 112 | + |
class TestNormalizeChunkGrid:
    """Checks for the ``_normalize_chunk_grid`` helper."""

    def test_broadcast_and_match(self) -> None:
        """Cover single-grid broadcast, a matching list, and a length mismatch."""
        single = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(10, 10, 10)))
        triple = [single, single, single]

        # One grid requested for three variables yields three entries.
        broadcasted = _normalize_chunk_grid(single, 3)
        assert len(broadcasted) == 3

        # A list that already has the requested length compares equal to the input.
        matched = _normalize_chunk_grid(triple, 3)
        assert len(matched) == 3
        assert matched == triple

        # Two grids for three variables is rejected.
        too_short = [single, single]
        with pytest.raises(ValueError):
            _normalize_chunk_grid(too_short, 3)
| 132 | + |
| 133 | + |
class TestNormalizeCompressor:
    """Checks for the ``_normalize_compressor`` helper."""

    def test_broadcast_and_match(self) -> None:
        """Cover ``None`` broadcast, single-compressor broadcast, mixed lists, and a length mismatch."""
        blosc = Blosc(cname="zstd", clevel=5, shuffle="shuffle")

        # No compressor requested: every returned entry is None.
        expanded_none = _normalize_compressor(None, 3)
        assert not any(entry is not None for entry in expanded_none)

        # One compressor requested for three variables yields three entries.
        expanded_single = _normalize_compressor(blosc, 3)
        assert len(expanded_single) == 3

        # A None entry inside a list keeps its position.
        mixed = _normalize_compressor([blosc, None, blosc], 3)
        assert mixed[1] is None

        # Two compressors for three variables is rejected.
        too_short = [blosc, blosc]
        with pytest.raises(ValueError):
            _normalize_compressor(too_short, 3)
0 commit comments