diff --git a/virtualizarr/parsers/hdf/hdf.py b/virtualizarr/parsers/hdf/hdf.py index d9cef4ff..874bd153 100644 --- a/virtualizarr/parsers/hdf/hdf.py +++ b/virtualizarr/parsers/hdf/hdf.py @@ -35,10 +35,16 @@ from h5py import Group as H5Group +def _squeeze_indices(shape: tuple) -> list[int]: + """Return indices of dimensions where shape is greater than 1.""" + return [i for i, s in enumerate(shape) if s > 1] + + def _construct_manifest_array( filepath: str, dataset: H5Dataset, group: str, + squeeze: bool = False, ) -> ManifestArray: """ Construct a ManifestArray from an h5py dataset @@ -60,6 +66,15 @@ def _construct_manifest_array( # chunk dimensions (enforced by zarr-python >= 3.2.0). See # https://github.com/zarr-developers/zarr-python/issues/3711. chunks = dataset.chunks or tuple(max(s, 1) for s in dataset.shape) + # When squeeze=True, we keep only dimensions of size > 1. + # So squeeze=True on its own drops any dim that was length 0 (or less) in the original dataset and the clamp is technically redundant. + # But when squeeze=False, the clamp is necessary to prevent having zero-length chunk dimensions. + keep_indices = ( + _squeeze_indices(dataset.shape) if squeeze else list(range(len(dataset.shape))) + ) + keep_chunks = tuple(chunks[i] for i in keep_indices) + keep_shape = tuple(dataset.shape[i] for i in keep_indices) + codecs = codecs_from_dataset(dataset) attrs = _extract_attrs(dataset) dtype = dataset.dtype @@ -82,16 +97,17 @@ def _construct_manifest_array( fill_value = dataset.fillvalue.item() dims = tuple(_dataset_dims(dataset, group=group)) + keep_dims = tuple(dims[i] for i in keep_indices) metadata = create_v3_array_metadata( - shape=dataset.shape, + shape=keep_shape, data_type=dtype, - chunk_shape=chunks, + chunk_shape=keep_chunks, fill_value=fill_value, codecs=codec_configs, - dimension_names=dims, + dimension_names=keep_dims, attributes=attrs, ) - manifest = _dataset_chunk_manifest(filepath, dataset) + manifest = _dataset_chunk_manifest(filepath, dataset, squeeze=squeeze) return ManifestArray(metadata=metadata, chunkmanifest=manifest) @@ -101,6 +117,7 @@ def _construct_manifest_group( *, group: str | None = None, drop_variables: Iterable[str] | None = None, + squeeze: bool = False, ) -> ManifestGroup: """ Construct a virtual Group from a HDF dataset. @@ -120,7 +137,9 @@ def _construct_manifest_group( drop_variables = set(drop_variables or ()) | set(non_coordinate_dimension_vars) group_name = str(g.name) # NOTE: this will always include leading "/" arrays = { - key: _construct_manifest_array(filepath, dataset, group_name) + key: _construct_manifest_array( + filepath, dataset, group_name, squeeze=squeeze + ) for key in g.keys() if key not in drop_variables if isinstance(dataset := g[key], h5py.Dataset) @@ -130,6 +149,7 @@ def _construct_manifest_group( filepath, reader, group=str(Path(group) / key) if group is not None else key, + squeeze=squeeze, ) for key in g.keys() if key not in drop_variables @@ -146,6 +166,7 @@ def __init__( group: str | None = None, drop_variables: Iterable[str] | None = None, reader_factory: ReaderFactory = BlockStoreReader, + squeeze: bool = False, ): """ Instantiate a parser that can be used to virtualize HDF5/NetCDF4 files using the @@ -163,10 +184,13 @@ def __init__( Must return an object implementing the [ReadableFile][obspec_utils.protocols.ReadableFile] protocol. Default is [BlockStoreReader][obspec_utils.readers.BlockStoreReader]. + squeeze + If `True`, remove dimensions of size 1 from arrays (default: `False`). """ self.group = group self.drop_variables = drop_variables self.reader_factory = reader_factory + self.squeeze = squeeze def __call__( self, @@ -196,6 +220,7 @@ def __call__( reader=reader, group=self.group, drop_variables=self.drop_variables, + squeeze=self.squeeze, ) # Convert to a manifest store return ManifestStore(registry=registry, group=manifest_group) @@ -204,6 +229,7 @@ def __call__( def _dataset_chunk_manifest( filepath: str, dataset: H5Dataset, + squeeze: bool = False, ) -> ChunkManifest: """ Generate ChunkManifest for HDF5 dataset. @@ -221,6 +247,9 @@ def _dataset_chunk_manifest( A Virtualizarr ChunkManifest """ dsid = dataset.id + keep_indices = ( + _squeeze_indices(dataset.shape) if squeeze else list(range(len(dataset.shape))) + ) if dataset.chunks is None: if dsid.get_offset() is None: chunk_manifest = ChunkManifest(entries={}, shape=dataset.shape) @@ -231,7 +260,7 @@ def _dataset_chunk_manifest( lengths=np.array(dsid.get_storage_size(), dtype=np.uint64), ) else: - key_list = [0] * (len(dataset.shape) or 1) + key_list = [0] * (len(keep_indices) or 1) key = ".".join(map(str, key_list)) chunk_entry: ChunkEntry = ChunkEntry.with_validation( # type: ignore[attr-defined] @@ -268,6 +297,15 @@ def add_chunk_info(blob): for index in range(num_chunks): add_chunk_info(dsid.get_chunk_info(index)) + # we squeeze here rather than in get_key + squeeze_axes = tuple( + i for i in range(len(dataset.chunks)) if i not in set(keep_indices) + ) + if squeeze_axes: + paths = np.squeeze(paths, axis=squeeze_axes) + offsets = np.squeeze(offsets, axis=squeeze_axes) + lengths = np.squeeze(lengths, axis=squeeze_axes) + chunk_manifest = ChunkManifest.from_arrays( paths=paths, # type: ignore offsets=offsets, diff --git a/virtualizarr/tests/test_parsers/conftest.py b/virtualizarr/tests/test_parsers/conftest.py index 23d30a4f..78fb1aff 100644 --- a/virtualizarr/tests/test_parsers/conftest.py +++ b/virtualizarr/tests/test_parsers/conftest.py @@ -479,6 +479,47 @@ def big_endian_dtype_hdf5_file(tmpdir): return filepath +@pytest.fixture( + params=[ + {"N": 50, "M": 100, "chunked": True, "chunks": (5, 25)}, + {"N": 50, "M": 100, "chunked": False, "chunks": None}, + {"N": 1, "M": 100, "chunked": True, "chunks": (1, 25)}, + ], + ids=["chunked", "not_chunked", "singleton_chunked"], +) +def singleton_padded_dimension_hdf5_file(tmp_path: Path, request) -> tuple: + """HDF5 file mimicking MATLAB layout: a 2D array (N, M) plus coordinate + arrays shaped (N, 1) and (1, M).""" + N = request.param["N"] + M = request.param["M"] + chunked = request.param["chunked"] + chunks = request.param["chunks"] + filepath = str(tmp_path / "singleton_dimension_layout.nc") + + dataset_args = { + "data": { + "name": "data", + "data": np.random.random((N, M)), + }, + "row_coord": { + "name": "row_coord", + "data": np.random.random((N, 1)), + }, + "col_coord": { + "name": "col_coord", + "data": np.random.random((1, M)), + }, + } + if chunks is not None: + dataset_args["data"]["chunks"] = chunks + + with h5py.File(filepath, "w") as f: + for v in dataset_args.values(): + f.create_dataset(**v) + + return f"file://{filepath}", N, M, chunked, chunks + + @pytest.fixture() def dmrpp_xml_simple(): """Return a minimal valid DMR++ XML string for testing.""" diff --git a/virtualizarr/tests/test_parsers/test_hdf/test_hdf.py b/virtualizarr/tests/test_parsers/test_hdf/test_hdf.py index 9854174f..7bd270fe 100644 --- a/virtualizarr/tests/test_parsers/test_hdf/test_hdf.py +++ b/virtualizarr/tests/test_parsers/test_hdf/test_hdf.py @@ -41,6 +41,32 @@ def test_chunked_roundtrip(self, chunked_roundtrip_hdf5_url): manifest_store = manifest_store_from_hdf_url(chunked_roundtrip_hdf5_url) assert manifest_store._group.arrays["var2"].manifest.shape_chunk_grid == (2, 8) + def test_singleton_dimensions_squeezed(self, singleton_padded_dimension_hdf5_file): + url, N, M, chunked, chunks = singleton_padded_dimension_hdf5_file + manifest_store = manifest_store_from_hdf_url(url, squeeze=True) + expected_data_shape = (M,) if N == 1 else (N, M) + assert manifest_store._group.arrays["data"].shape == expected_data_shape + if chunked: + expected_chunks = tuple(c for c, n in zip(chunks, (N, M)) if n != 1) + else: + expected_chunks = expected_data_shape + assert manifest_store._group.arrays["data"].chunks == expected_chunks + assert manifest_store._group.arrays["row_coord"].shape == ( + () if N == 1 else (N,) + ) + assert manifest_store._group.arrays["col_coord"].shape == (M,) + + def test_singleton_dimensions_not_squeezed( + self, singleton_padded_dimension_hdf5_file + ): + url, N, M, chunked, chunks = singleton_padded_dimension_hdf5_file + manifest_store = manifest_store_from_hdf_url(url) + assert manifest_store._group.arrays["data"].shape == (N, M) + expected_chunks = chunks if chunked else (N, M) + assert manifest_store._group.arrays["data"].chunks == expected_chunks + assert manifest_store._group.arrays["row_coord"].shape == (N, 1) + assert manifest_store._group.arrays["col_coord"].shape == (1, M) + @requires_hdf5plugin @requires_imagecodecs diff --git a/virtualizarr/tests/test_parsers/test_hdf/test_hdf_integration.py b/virtualizarr/tests/test_parsers/test_hdf/test_hdf_integration.py index e779317c..b195549b 100644 --- a/virtualizarr/tests/test_parsers/test_hdf/test_hdf_integration.py +++ b/virtualizarr/tests/test_parsers/test_hdf/test_hdf_integration.py @@ -96,6 +96,32 @@ def test_non_coord_dim_roundtrip(self, tmp_path, non_coord_dim, local_registry): with xr.open_dataset(kerchunk_file, engine="kerchunk") as roundtrip: xrt.assert_equal(ds, roundtrip) + def test_singleton_dim_roundtrip( + self, tmp_path, singleton_padded_dimension_hdf5_file, local_registry + ): + import numpy as np + + parser = HDFParser(squeeze=True) + filepath, *_ = singleton_padded_dimension_hdf5_file + with ( + xr.open_dataset( + filepath, engine="h5netcdf", backend_kwargs={"phony_dims": "sort"} + ).squeeze(drop=True) as ds, + open_virtual_dataset( + url=filepath, + registry=local_registry, + parser=parser, + ) as vds, + ): + kerchunk_file = str(tmp_path / "kerchunk.json") + vds.vz.to_kerchunk(kerchunk_file, format="json") + with xr.open_dataset(kerchunk_file, engine="kerchunk") as roundtrip: + for var_name in ds.data_vars: + assert ds[var_name].shape == roundtrip[var_name].shape + np.testing.assert_array_equal( + ds[var_name].values, roundtrip[var_name].values + ) + @requires_icechunk def test_cf_fill_value_roundtrip( self, tmp_path, cf_fill_value_hdf5_file, local_registry diff --git a/virtualizarr/tests/utils.py b/virtualizarr/tests/utils.py index 4ce2fece..7d6c1814 100644 --- a/virtualizarr/tests/utils.py +++ b/virtualizarr/tests/utils.py @@ -40,8 +40,8 @@ def obstore_http(url: str) -> ObjectStore: return store -def manifest_store_from_hdf_url(url, group: str | None = None): +def manifest_store_from_hdf_url(url, group: str | None = None, squeeze: bool = False): registry: ObjectStoreRegistry = ObjectStoreRegistry() registry.register(url, obstore_local(url=url)) - parser = HDFParser(group=group) + parser = HDFParser(group=group, squeeze=squeeze) return parser(url=url, registry=registry)