Skip to content

Commit 73bf936

Browse files
committed
Begin adding tests
1 parent 368a472 commit 73bf936

2 files changed

Lines changed: 298 additions & 0 deletions

File tree

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Integration tests for chunking module."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import dask
8+
import numpy as np
9+
import pytest
10+
from segy.factory import SegyFactory
11+
from segy.schema import HeaderField
12+
from segy.schema import ScalarType
13+
from segy.standards import get_segy_standard
14+
15+
from mdio.api.io import open_mdio
16+
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
17+
from mdio.builder.schemas.chunk_grid import RegularChunkShape
18+
from mdio.builder.schemas.compressors import Blosc
19+
from mdio.builder.template_registry import TemplateRegistry
20+
from mdio.converters.segy import segy_to_mdio
21+
from mdio.transpose_writers.chunking import from_variable
22+
23+
if TYPE_CHECKING:
24+
from pathlib import Path
25+
26+
from segy.schema import SegySpec
27+
28+
29+
dask.config.set(scheduler="synchronous")
30+
31+
32+
@pytest.fixture
33+
def synthetic_segy_spec() -> SegySpec:
34+
"""SEG-Y specification for synthetic testing."""
35+
trace_header_fields = [
36+
HeaderField(name="inline", byte=189, format=ScalarType.INT32),
37+
HeaderField(name="crossline", byte=193, format=ScalarType.INT32),
38+
HeaderField(name="cdp_x", byte=181, format=ScalarType.INT32),
39+
HeaderField(name="cdp_y", byte=185, format=ScalarType.INT32),
40+
]
41+
return get_segy_standard(1.0).customize(trace_header_fields=trace_header_fields)
42+
43+
44+
@pytest.fixture
45+
def synthetic_segy_file(fake_segy_tmp: Path, synthetic_segy_spec: SegySpec) -> Path:
46+
"""Create a small synthetic 3D SEG-Y file."""
47+
segy_path = fake_segy_tmp / "synthetic_3d.sgy"
48+
inlines, crosslines, num_samples = np.arange(1, 9), np.arange(1, 17), 64
49+
num_traces = len(inlines) * len(crosslines)
50+
51+
factory = SegyFactory(spec=synthetic_segy_spec, sample_interval=4000, samples_per_trace=num_samples)
52+
headers = factory.create_trace_header_template(num_traces)
53+
samples = factory.create_trace_sample_template(num_traces)
54+
55+
trace_idx = 0
56+
for inline in inlines:
57+
for crossline in crosslines:
58+
headers["inline"][trace_idx] = inline
59+
headers["crossline"][trace_idx] = crossline
60+
headers["cdp_x"][trace_idx] = inline * 100
61+
headers["cdp_y"][trace_idx] = crossline * 100
62+
headers["coordinate_scalar"][trace_idx] = -100
63+
samples[trace_idx] = np.sin(np.linspace(0, 2 * np.pi * inline, num_samples)) * crossline
64+
trace_idx += 1
65+
66+
with segy_path.open(mode="wb") as fp:
67+
fp.write(factory.create_textual_header())
68+
fp.write(factory.create_binary_header())
69+
fp.write(factory.create_traces(headers, samples))
70+
71+
return segy_path
72+
73+
74+
@pytest.fixture
75+
def mdio_dataset(synthetic_segy_file: Path, synthetic_segy_spec: SegySpec, zarr_tmp: Path) -> Path:
76+
"""Convert synthetic SEG-Y to MDIO."""
77+
mdio_path = zarr_tmp / "test_dataset.mdio"
78+
segy_to_mdio(
79+
segy_spec=synthetic_segy_spec,
80+
mdio_template=TemplateRegistry().get("PostStack3DTime"),
81+
input_path=synthetic_segy_file,
82+
output_path=mdio_path,
83+
overwrite=True,
84+
)
85+
return mdio_path
86+
87+
88+
def test_single_variable_rechunk(mdio_dataset: Path) -> None:
89+
"""Test creating a single rechunked variable."""
90+
chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(2, 16, 64)))
91+
compressor = Blosc(cname="zstd", clevel=5, shuffle="shuffle")
92+
93+
from_variable(
94+
dataset_path=mdio_dataset,
95+
source_variable="amplitude",
96+
new_variable="fast_inline",
97+
chunk_grid=chunk_grid,
98+
compressor=compressor,
99+
copy_metadata=True,
100+
)
101+
102+
ds = open_mdio(mdio_dataset)
103+
104+
# Verify new variable exists with correct chunks
105+
assert "fast_inline" in ds.data_vars
106+
assert ds["fast_inline"].encoding["chunks"] == (2, 16, 64)
107+
108+
# Verify data integrity
109+
np.testing.assert_array_equal(ds["amplitude"].values, ds["fast_inline"].values)
110+
111+
# Verify metadata copied
112+
assert ds["fast_inline"].attrs == ds["amplitude"].attrs
113+
114+
115+
def test_multiple_variables_with_broadcasting(mdio_dataset: Path) -> None:
116+
"""Test creating multiple variables with different settings and broadcasting."""
117+
new_variables = ["fast_inline", "fast_crossline", "fast_time"]
118+
chunk_grids = [
119+
RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(2, 16, 64))),
120+
RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(8, 4, 64))),
121+
RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(8, 16, 16))),
122+
]
123+
compressor = Blosc(cname="lz4", clevel=3, shuffle="shuffle") # Single compressor broadcasts
124+
125+
from_variable(
126+
dataset_path=mdio_dataset,
127+
source_variable="amplitude",
128+
new_variable=new_variables,
129+
chunk_grid=chunk_grids,
130+
compressor=compressor,
131+
copy_metadata=False,
132+
)
133+
134+
ds = open_mdio(mdio_dataset)
135+
136+
# Verify all variables created with correct chunks
137+
expected_chunks = [(2, 16, 64), (8, 4, 64), (8, 16, 16)]
138+
for var_name, chunks in zip(new_variables, expected_chunks, strict=True):
139+
assert var_name in ds.data_vars
140+
assert ds[var_name].encoding["chunks"] == chunks
141+
assert "compressors" in ds[var_name].encoding
142+
np.testing.assert_array_equal(ds["amplitude"].values, ds[var_name].values)
143+
144+
# Metadata should not be copied
145+
assert len(ds[var_name].attrs) == 0 or ds[var_name].attrs != ds["amplitude"].attrs
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""Unit tests for chunking module helper functions."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
import pytest
9+
import xarray as xr
10+
import zarr
11+
12+
from mdio.api.io import open_mdio
13+
from mdio.api.io import to_mdio
14+
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
15+
from mdio.builder.schemas.chunk_grid import RegularChunkShape
16+
from mdio.builder.schemas.compressors import Blosc
17+
from mdio.constants import ZarrFormat
18+
from mdio.transpose_writers.chunking import _normalize_chunk_grid
19+
from mdio.transpose_writers.chunking import _normalize_compressor
20+
from mdio.transpose_writers.chunking import _normalize_new_variable
21+
from mdio.transpose_writers.chunking import _remove_fillvalue_attrs
22+
from mdio.transpose_writers.chunking import _validate_inputs
23+
24+
if TYPE_CHECKING:
25+
from pathlib import Path
26+
27+
28+
class TestRemoveFillvalueAttrs:
29+
"""Tests for _remove_fillvalue_attrs helper function."""
30+
31+
@pytest.mark.parametrize("zarr_format", [ZarrFormat.V2, ZarrFormat.V3])
32+
def test_remove_fillvalue_after_mdio_serialization(self, tmp_path: Path, zarr_format: ZarrFormat) -> None:
33+
"""Test that _FillValue is removed after MDIO serialization in both Zarr v2 and v3."""
34+
# Create dataset with NaN values (will add _FillValue on serialization)
35+
data = np.array([[1.0, 2.0, np.nan], [3.0, np.nan, 4.0]], dtype=np.float32)
36+
ds = xr.Dataset(
37+
{"var1": (["x", "y"], data, {"units": "meters"})},
38+
coords={"x": (["x"], [0, 1], {"axis": "X"})},
39+
)
40+
41+
# Write and read back with MDIO
42+
mdio_path = tmp_path / f"test_{zarr_format}.mdio"
43+
with zarr.config.set(default_zarr_format=zarr_format):
44+
to_mdio(ds, mdio_path, mode="w")
45+
ds_read = open_mdio(mdio_path)
46+
47+
# Apply function and verify _FillValue is removed everywhere
48+
_remove_fillvalue_attrs(ds_read)
49+
50+
for var_name in list(ds_read.data_vars) + list(ds_read.coords):
51+
assert "_FillValue" not in ds_read[var_name].attrs
52+
53+
# Verify other attributes preserved
54+
assert ds_read["var1"].attrs["units"] == "meters"
55+
assert ds_read["x"].attrs["axis"] == "X"
56+
57+
58+
class TestValidateInputs:
59+
"""Tests for _validate_inputs helper function."""
60+
61+
@pytest.mark.parametrize(
62+
("new_variable", "chunk_grid", "compressor", "should_pass"),
63+
[
64+
("var1", "grid", "comp", True),
65+
(["var1", "var2"], "grid", "comp", True),
66+
("var1", "grid", None, True),
67+
(123, "grid", "comp", False),
68+
([], "grid", "comp", False),
69+
("var1", [], "comp", False),
70+
("var1", "grid", [], False),
71+
],
72+
)
73+
def test_validation(
74+
self,
75+
new_variable: str | list | int,
76+
chunk_grid: str | list,
77+
compressor: str | list | None,
78+
should_pass: bool,
79+
) -> None:
80+
"""Test input validation with various combinations."""
81+
grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(10, 10)))
82+
comp = Blosc(cname="zstd", clevel=5, shuffle="shuffle")
83+
84+
# Replace placeholders
85+
if chunk_grid == "grid":
86+
chunk_grid = grid
87+
if compressor == "comp":
88+
compressor = comp
89+
90+
if should_pass:
91+
_validate_inputs(new_variable, chunk_grid, compressor) # type: ignore[arg-type]
92+
else:
93+
with pytest.raises((TypeError, ValueError)):
94+
_validate_inputs(new_variable, chunk_grid, compressor) # type: ignore[arg-type]
95+
96+
97+
class TestNormalizeNewVariable:
98+
"""Tests for _normalize_new_variable helper function."""
99+
100+
@pytest.mark.parametrize(
101+
("input_value", "expected"),
102+
[
103+
("var1", ["var1"]),
104+
(["var1", "var2"], ["var1", "var2"]),
105+
],
106+
)
107+
def test_normalize(self, input_value: str | list[str], expected: list[str]) -> None:
108+
"""Test new_variable normalization."""
109+
result = _normalize_new_variable(input_value)
110+
assert result == expected
111+
112+
113+
class TestNormalizeChunkGrid:
114+
"""Tests for _normalize_chunk_grid helper function."""
115+
116+
def test_broadcast_and_match(self) -> None:
117+
"""Test chunk grid broadcasting and matching."""
118+
grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=(10, 10, 10)))
119+
grids = [grid] * 3
120+
121+
# Single grid broadcasts
122+
assert len(_normalize_chunk_grid(grid, 3)) == 3
123+
124+
# List matches length
125+
result = _normalize_chunk_grid(grids, 3)
126+
assert len(result) == 3
127+
assert result == grids
128+
129+
# Mismatch raises error
130+
with pytest.raises(ValueError):
131+
_normalize_chunk_grid([grid, grid], 3)
132+
133+
134+
class TestNormalizeCompressor:
135+
"""Tests for _normalize_compressor helper function."""
136+
137+
def test_broadcast_and_match(self) -> None:
138+
"""Test compressor broadcasting and matching."""
139+
comp = Blosc(cname="zstd", clevel=5, shuffle="shuffle")
140+
141+
# None broadcasts
142+
assert all(c is None for c in _normalize_compressor(None, 3))
143+
144+
# Single compressor broadcasts
145+
assert len(_normalize_compressor(comp, 3)) == 3
146+
147+
# List with None entries
148+
result = _normalize_compressor([comp, None, comp], 3)
149+
assert result[1] is None
150+
151+
# Mismatch raises error
152+
with pytest.raises(ValueError):
153+
_normalize_compressor([comp, comp], 3)

0 commit comments

Comments
 (0)