Skip to content

Commit 1a95f82

Browse files
committed
create_empty_like
1 parent 325466f commit 1a95f82

File tree

12 files changed

+268
-77
lines changed

12 files changed

+268
-77
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ dev = [
5555
"pre-commit-hooks>=6.0.0",
5656
"pytest>=8.4.2",
5757
"pytest-dependency>=0.6.0",
58+
"pytest-order>=1.3.0",
5859
"typeguard>=4.4.4",
5960
"xdoctest[colors]>=1.3.0",
6061
"Pygments>=2.19.2"

src/mdio/api/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Public API."""
22

33
from mdio.api.create import create_empty
4+
from mdio.api.create import create_empty_like
45

5-
__all__ = ["create_empty"]
6+
__all__ = ["create_empty", "create_empty_like"]

src/mdio/api/create.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
from __future__ import annotations
44

5+
from datetime import UTC
6+
from datetime import datetime
57
from typing import TYPE_CHECKING
68

79
from mdio.api.io import _normalize_path
10+
from mdio.api.io import open_mdio
811
from mdio.api.io import to_mdio
912
from mdio.builder.template_registry import TemplateRegistry
1013
from mdio.builder.xarray_builder import to_xarray_dataset
@@ -84,3 +87,81 @@ def create_empty( # noqa PLR0913
8487
meta_ds = dataset[drop_vars_delayed + ["trace_mask"]]
8588
to_mdio(meta_ds, output_path=output_path, mode="r+", compute=True)
8689

90+
91+
def create_empty_like( # noqa PLR0913
92+
input_path: UPath | Path | str,
93+
output_path: UPath | Path | str,
94+
keep_coordinates: bool = False,
95+
overwrite: bool = False,
96+
) -> xr_Dataset:
97+
"""A function that creates an empty MDIO v1 file with the same structure as an existing one.
98+
99+
Args:
100+
input_path: The path of the input MDIO file.
101+
output_path: The path of the output MDIO file.
102+
If None, the output will not be written to disk.
103+
keep_coordinates: Whether to keep the coordinates in the output file.
104+
overwrite: Whether to overwrite the output file if it exists.
105+
106+
Returns:
107+
The output MDIO dataset.
108+
109+
Raises:
110+
FileExistsError: If the output location already exists and overwrite is False.
111+
"""
112+
input_path = _normalize_path(input_path)
113+
output_path = _normalize_path(output_path) if output_path is not None else None
114+
115+
if not overwrite and output_path is not None and output_path.exists():
116+
err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
117+
raise FileExistsError(err)
118+
119+
ds = open_mdio(input_path)
120+
121+
# Create a copy with the same structure but no data or,
122+
# optionally, coordinates
123+
ds_output = ds.copy(data=None).reset_coords(drop=not keep_coordinates)
124+
125+
# Dataset
126+
# Keep the name (which is the same as the used template name) and the original API version
127+
# ds_output.attrs["name"]
128+
# ds_output.attrs["apiVersion"]
129+
ds_output.attrs["createdOn"] = datetime.now(UTC)
130+
131+
# Coordinates
132+
if not keep_coordinates:
133+
for coord_name in ds_output.coords:
134+
ds_output[coord_name].attrs["unitsV1"] = None
135+
136+
# MDIO attributes
137+
attr = ds_output.attrs["attributes"]
138+
if attr is not None:
139+
attr.pop("gridOverrides", None) # Empty dataset should not have gridOverrides
140+
# Keep the original values for the following attributes
141+
# attr["defaultVariableName"]
142+
# attr["surveyType"]
143+
# attr["gatherType"]
144+
145+
# "All traces should be marked as dead in empty dataset"
146+
if "trace_mask" in ds_output.variables:
147+
ds_output["trace_mask"][:] = False
148+
149+
# Data variable
150+
var_name = attr["defaultVariableName"]
151+
var = ds_output[var_name]
152+
var.attrs["statsV1"] = None
153+
if not keep_coordinates:
154+
var.attrs["unitsV1"] = None
155+
156+
# SEG-Y file header
157+
if "segy_file_header" in ds_output.variables:
158+
segy_file_header = ds_output["segy_file_header"]
159+
if segy_file_header is not None:
160+
segy_file_header.attrs["textHeader"] = None
161+
segy_file_header.attrs["binaryHeader"] = None
162+
segy_file_header.attrs["rawBinaryHeader"] = None
163+
164+
if output_path is not None:
165+
to_mdio(ds_output, output_path=output_path, mode="w", compute=True)
166+
167+
return ds_output

src/mdio/builder/dataset_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from enum import auto
77
from typing import Any
88

9-
from mdio import __version__
109
from mdio.builder.formatting_html import dataset_builder_repr_html
1110
from mdio.builder.schemas.compressors import ZFP
1211
from mdio.builder.schemas.compressors import Blosc
@@ -59,6 +58,8 @@ class MDIODatasetBuilder:
5958
"""
6059

6160
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
61+
from mdio import __version__ # noqa: PLC0415 - fixed circular import in mdio package and dataset_builder.py
62+
6263
self._metadata = DatasetMetadata(
6364
name=name,
6465
api_version=__version__,

src/mdio/converters/__init__.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
11
"""MDIO Data conversion API."""
22

3-
from mdio.converters.mdio import mdio_to_segy
4-
from mdio.converters.segy import segy_to_mdio
3+
from typing import TYPE_CHECKING
4+
from typing import Any
5+
6+
if TYPE_CHECKING:
7+
from mdio.converters.mdio import mdio_to_segy
8+
from mdio.converters.segy import segy_to_mdio
59

610
__all__ = ["mdio_to_segy", "segy_to_mdio"]
11+
12+
13+
def __getattr__(name: str) -> Any: # noqa: ANN401 - required for dynamic attribute access
14+
"""Lazy import for converters to avoid circular imports."""
15+
if name == "mdio_to_segy":
16+
from mdio.converters.mdio import ( # noqa: PLC0415 - intentionally inside the function to avoid circular imports
17+
mdio_to_segy,
18+
)
19+
20+
return mdio_to_segy
21+
22+
if name == "segy_to_mdio":
23+
from mdio.converters.segy import ( # noqa: PLC0415 - intentionally inside the function to avoid circular imports
24+
segy_to_mdio,
25+
)
26+
27+
return segy_to_mdio
28+
29+
err = f"module {__name__!r} has no attribute {name!r}"
30+
raise AttributeError(err)

tests/conftest.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def zarr_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
4747
return tmp_path_factory.mktemp(r"mdio")
4848

4949

50+
@pytest.fixture(scope="session")
51+
def teapot_mdio_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
52+
"""Make a temp file for the output MDIO."""
53+
return tmp_path_factory.mktemp(r"teapot.mdio")
54+
55+
56+
@pytest.fixture(scope="module")
57+
def mdio_4d_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
58+
"""Make a temp file for the output MDIO."""
59+
return tmp_path_factory.mktemp(r"tmp_4d.mdio")
60+
61+
5062
@pytest.fixture(scope="module")
5163
def zarr_tmp2(tmp_path_factory: pytest.TempPathFactory) -> Path: # pragma: no cover - used by disabled test
5264
"""Make a temp file for the output MDIO."""
@@ -63,4 +75,28 @@ def segy_export_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
6375
@pytest.fixture(scope="class")
6476
def empty_mdio_dir(tmp_path_factory: pytest.TempPathFactory) -> Path:
6577
"""Make a temp file for empty MDIO testing."""
66-
return tmp_path_factory.mktemp(r"empty_mdio")
78+
return tmp_path_factory.mktemp(r"empty_mdio_dir")
79+
80+
81+
#
82+
# Uncomment the function below for local debugging
83+
#
84+
# @pytest.fixture(scope="session")
85+
# def tmp_path_factory() -> pytest.TempPathFactory:
86+
# """Custom tmp_path_factory implementation for local debugging."""
87+
# from pathlib import Path # noqa: PLC0415
88+
89+
# class DebugTempPathFactory:
90+
# def __init__(self) -> None:
91+
# self._retention_policy = "all"
92+
93+
# def mktemp(self, basename: str, numbered: bool = True) -> Path:
94+
# _ = numbered
95+
# path = self.getbasetemp() / basename
96+
# path.mkdir(parents=True, exist_ok=True)
97+
# return path
98+
99+
# def getbasetemp(self) -> Path:
100+
# return Path("tmp")
101+
102+
# return DebugTempPathFactory()

0 commit comments

Comments
 (0)