Skip to content

Commit 251e2f6

Browse files
committed
Merge upstream/main' into create_empty
1 parent 7a65c07 commit 251e2f6

File tree

4 files changed

+138
-51
lines changed

4 files changed

+138
-51
lines changed

src/mdio/creators/mdio.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from xarray import Dataset as xr_Dataset
2121

2222
from mdio.builder.schemas import Dataset
23+
from mdio.builder.templates.base import AbstractDatasetTemplate
2324
from mdio.core.dimension import Dimension
2425

2526

2627
def create_empty( # noqa PLR0913
27-
mdio_template_name: str,
28+
mdio_template: AbstractDatasetTemplate | str,
2829
dimensions: list[Dimension],
2930
output_path: UPath | Path | str,
3031
headers: HeaderSpec | None = None,
@@ -33,7 +34,14 @@ def create_empty( # noqa PLR0913
3334
"""A function that creates an empty MDIO v1 file with known dimensions.
3435
3536
Args:
36-
mdio_template_name: The MDIO template to use to define the dataset structure.
37+
mdio_template: The MDIO template or template name to use to define the dataset structure.
38+
NOTE: If you want to have a unit-aware MDIO model, you need to add the units
39+
to the template before calling this function. For example:
40+
'unit_aware_template = TemplateRegistry().get("PostStack3DTime")'
41+
'unit_aware_template.add_units({"time": UNITS_SECOND})'
42+
'unit_aware_template.add_units({"cdp_x": UNITS_METER})'
43+
'unit_aware_template.add_units({"cdp_y": UNITS_METER})'
44+
'create_empty(unit_aware_template, dimensions, output_path, headers, overwrite)'
3745
dimensions: The dimensions of the MDIO file.
3846
output_path: The universal path for the output MDIO v1 file.
3947
headers: SEG-Y v1.0 trace headers. Defaults to None.
@@ -50,13 +58,11 @@ def create_empty( # noqa PLR0913
5058

5159
header_dtype = to_structured_type(headers.dtype) if headers else None
5260
grid = Grid(dims=dimensions)
53-
mdio_template = TemplateRegistry().get(mdio_template_name)
54-
mdio_ds: Dataset = mdio_template.build_dataset(
55-
name=mdio_template_name,
56-
sizes=grid.shape,
57-
horizontal_coord_unit=None,
58-
header_dtype=header_dtype,
59-
)
61+
if isinstance(mdio_template, str):
62+
# A template name is passed in. Get a unit-unaware template from registry
63+
mdio_template = TemplateRegistry().get(mdio_template)
64+
# Build the dataset using the template
65+
mdio_ds: Dataset = mdio_template.build_dataset(name=mdio_template.name, sizes=grid.shape, header_dtype=header_dtype)
6066

6167
# Convert to xarray dataset
6268
xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)

tests/integration/test_create_empty.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,61 @@
2626

2727

2828
from tests.integration.testing_helpers import get_values
29-
from tests.integration.testing_helpers import validate_variable
29+
from tests.integration.testing_helpers import validate_xr_variable
3030

3131
from mdio import __version__
3232
from mdio.api.io import open_mdio
3333
from mdio.api.io import to_mdio
3434
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
3535
from mdio.builder.schemas.v1.stats import SummaryStatistics
36+
from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate
3637
from mdio.converters.mdio import mdio_to_segy
3738
from mdio.core import Dimension
3839
from mdio.creators.mdio import create_empty
3940

41+
UNITS_NONE = None
42+
UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
43+
UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)
44+
UNITS_METER_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.METER_PER_SECOND)
45+
UNITS_FOOT = LengthUnitModel(length=LengthUnitEnum.FOOT)
46+
UNITS_FEET_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND)
47+
48+
49+
class PostStack3DVelocityTemplate(Seismic3DPostStackTemplate):
50+
"""Custom template that uses 'velocity' as the default variable name instead of 'amplitude'."""
51+
52+
@property
53+
def _default_variable_name(self) -> str:
54+
"""Override the default variable name."""
55+
return "velocity"
56+
57+
def __init__(self, data_domain: str, is_metric: bool) -> None:
58+
super().__init__(data_domain)
59+
if is_metric:
60+
self._units.update(
61+
{
62+
"time": UNITS_SECOND,
63+
"cdp_x": UNITS_METER,
64+
"cdp_y": UNITS_METER,
65+
"velocity": UNITS_METER_PER_SECOND,
66+
}
67+
)
68+
else:
69+
self._units.update(
70+
{
71+
"time": UNITS_SECOND,
72+
"cdp_x": UNITS_FOOT,
73+
"cdp_y": UNITS_FOOT,
74+
"velocity": UNITS_FEET_PER_SECOND,
75+
}
76+
)
77+
78+
@property
79+
def _name(self) -> str:
80+
"""Override the name of the template."""
81+
domain_suffix = self._data_domain.capitalize()
82+
return f"PostStack3DVelocity{domain_suffix}"
83+
4084

4185
class TestCreateEmptyPostStack3DTimeMdio:
4286
"""Tests for create_empty_mdio function."""
@@ -58,35 +102,51 @@ def _get_customized_v10_trace_header_spec(cls) -> HeaderSpec:
58102
@classmethod
59103
def _validate_empty_mdio_dataset(cls, ds: xr_Dataset, has_headers: bool) -> None:
60104
"""Validate an empty MDIO dataset structure and content."""
105+
assert ds.name == "PostStack3DVelocityTime"
61106
# Check that the dataset has the expected shape
62107
assert ds.sizes == {"inline": 345, "crossline": 188, "time": 1501}
63108

64109
# Validate the dimension coordinate variables
65-
validate_variable(ds, "inline", (345,), ("inline",), np.int32, range(1, 346), get_values)
66-
validate_variable(ds, "crossline", (188,), ("crossline",), np.int32, range(1, 189), get_values)
67-
validate_variable(ds, "time", (1501,), ("time",), np.int32, range(0, 3002, 2), get_values)
110+
validate_xr_variable(ds, "inline", {"inline": 345}, UNITS_NONE, np.int32, range(1, 346), get_values)
111+
validate_xr_variable(ds, "crossline", {"crossline": 188}, UNITS_NONE, np.int32, range(1, 189), get_values)
112+
validate_xr_variable(ds, "time", {"time": 1501}, UNITS_SECOND, np.int32, range(0, 3002, 2), get_values)
68113

69114
# Validate the non-dimensional coordinate variables (should be empty for empty dataset)
70-
validate_variable(ds, "cdp_x", (345, 188), ("inline", "crossline"), np.float64, None, None)
71-
validate_variable(ds, "cdp_y", (345, 188), ("inline", "crossline"), np.float64, None, None)
115+
validate_xr_variable(ds, "cdp_x", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64, None, None)
116+
validate_xr_variable(ds, "cdp_y", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64, None, None)
72117

73118
if has_headers:
74119
# Validate the headers (should be empty for empty dataset)
75120
# Infer the dtype from segy_spec and ignore endianness
76121
header_dtype = cls._get_customized_v10_trace_header_spec().dtype.newbyteorder("native")
77-
validate_variable(ds, "headers", (345, 188), ("inline", "crossline"), header_dtype, None, None)
78-
validate_variable(ds, "segy_file_header", (), (), np.dtype("U1"), None, None)
122+
validate_xr_variable(ds, "headers", {"inline": 345, "crossline": 188}, UNITS_NONE, header_dtype, None, None)
123+
validate_xr_variable(
124+
ds,
125+
"segy_file_header",
126+
dims={},
127+
units=UNITS_NONE,
128+
data_type=np.dtype("U1"),
129+
expected_values=None,
130+
actual_value_generator=None,
131+
)
79132
else:
80133
assert "headers" not in ds.variables
81134
assert "segy_file_header" not in ds.variables
82-
83135
# Validate the trace mask (should be all True for empty dataset)
84-
validate_variable(ds, "trace_mask", (345, 188), ("inline", "crossline"), np.bool_, None, None)
136+
validate_xr_variable(ds, "trace_mask", {"inline": 345, "crossline": 188}, UNITS_NONE, np.bool_, None, None)
85137
trace_mask = ds["trace_mask"].values
86138
assert not np.any(trace_mask), "All traces should be marked as dead in empty dataset"
87139

88140
# Validate the amplitude data (should be empty)
89-
validate_variable(ds, "amplitude", (345, 188, 1501), ("inline", "crossline", "time"), np.float32, None, None)
141+
validate_xr_variable(
142+
ds,
143+
"velocity",
144+
{"inline": 345, "crossline": 188, "time": 1501},
145+
UNITS_METER_PER_SECOND,
146+
np.float32,
147+
None,
148+
None,
149+
)
90150

91151
@classmethod
92152
def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite: bool = True) -> None:
@@ -101,8 +161,10 @@ def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite:
101161
# If later on, we want to export to SEG-Y, we need to provide the trace header spec.
102162
# The HeaderSpec can be either standard or customized.
103163
headers = cls._get_customized_v10_trace_header_spec() if create_headers else None
164+
165+
# Create an empty MDIO v1 metric post-stack 3D time velocity dataset
104166
create_empty(
105-
mdio_template_name="PostStack3DTime",
167+
mdio_template=PostStack3DVelocityTemplate(data_domain="time", is_metric=True),
106168
dimensions=dims,
107169
output_path=output_path,
108170
headers=headers,
@@ -138,7 +200,7 @@ def test_dataset_metadata(self, mdio_with_headers: Path) -> None:
138200
# Check basic metadata attributes
139201
expected_attrs = {
140202
"apiVersion": __version__,
141-
"name": "PostStack3DTime",
203+
"name": "PostStack3DVelocityTime",
142204
}
143205
actual_attrs_json = ds.attrs
144206

@@ -159,7 +221,7 @@ def test_dataset_metadata(self, mdio_with_headers: Path) -> None:
159221
assert attributes is not None
160222
assert len(attributes) == 3
161223
# Validate all attributes provided by the abstract template
162-
assert attributes["defaultVariableName"] == "amplitude"
224+
assert attributes["defaultVariableName"] == "velocity"
163225
assert attributes["surveyType"] == "3D"
164226
assert attributes["gatherType"] == "stacked"
165227

@@ -203,9 +265,9 @@ def test_overwrite_behavior(self, empty_mdio_dir: Path) -> None:
203265

204266
def test_populate_empty_dataset(self, mdio_with_headers: Path) -> None:
205267
"""Test showing how to populate empty dataset."""
206-
# Open an empty PostStack3DTime dataset with SEG-Y 1.0 headers
268+
# Open an empty PostStack3DVelocityTime dataset with SEG-Y 1.0 headers
207269
# NOTES:
208-
# When this empty dataset was created from the 'PostStack3DTime' template and dimensions,
270+
# When this empty dataset was created from the 'PostStack3DVelocityTime' template and dimensions,
209271
# * 'inline', 'crossline', and 'time' dimension coordinate variables were created and pre-populated
210272
# * 'cdp_x', 'cdp_y' non-dimensional coordinate variables were created
211273
# * 'amplitude' variable was created (the name of this variable is specified in the template)

tests/integration/test_segy_roundtrip_teapot.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
from segy.standards import get_segy_standard
1515
from tests.integration.testing_helpers import get_inline_header_values
1616
from tests.integration.testing_helpers import get_values
17-
from tests.integration.testing_helpers import validate_variable
17+
from tests.integration.testing_helpers import validate_xr_variable
1818

1919
from mdio import __version__
2020
from mdio import mdio_to_segy
2121
from mdio.api.io import open_mdio
22+
from mdio.builder.schemas.v1.units import LengthUnitEnum
23+
from mdio.builder.schemas.v1.units import LengthUnitModel
24+
from mdio.builder.schemas.v1.units import TimeUnitEnum
25+
from mdio.builder.schemas.v1.units import TimeUnitModel
2226
from mdio.builder.template_registry import TemplateRegistry
2327
from mdio.converters.segy import segy_to_mdio
2428
from mdio.segy.file import SegyFileWrapper
@@ -148,6 +152,11 @@ def raw_binary_header_teapot_dome() -> str:
148152
)
149153

150154

155+
UNITS_NONE = None
156+
UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
157+
UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)
158+
159+
151160
class TestTeapotRoundtrip:
152161
"""Tests for Teapot Dome data ingestion and export."""
153162

@@ -163,9 +172,13 @@ def test_teapot_import(
163172
164173
NOTE: This test must be executed before the 'TestReader' and 'TestExport' tests.
165174
"""
175+
unit_aware_template = TemplateRegistry().get("PostStack3DTime")
176+
unit_aware_template.add_units({"time": UNITS_SECOND})
177+
unit_aware_template.add_units({"cdp_x": UNITS_METER})
178+
unit_aware_template.add_units({"cdp_y": UNITS_METER})
166179
segy_to_mdio(
167180
segy_spec=teapot_segy_spec,
168-
mdio_template=TemplateRegistry().get("PostStack3DTime"),
181+
mdio_template=unit_aware_template,
169182
input_path=segy_input,
170183
output_path=zarr_tmp,
171184
overwrite=True,
@@ -224,38 +237,38 @@ def test_grid(self, zarr_tmp: Path, teapot_segy_spec: SegySpec) -> None:
224237
ds = open_mdio(zarr_tmp)
225238

226239
# Validate the dimension coordinate variables
227-
validate_variable(ds, "inline", (345,), ("inline",), np.int32, range(1, 346), get_values)
228-
validate_variable(ds, "crossline", (188,), ("crossline",), np.int32, range(1, 189), get_values)
229-
validate_variable(ds, "time", (1501,), ("time",), np.int32, range(0, 3002, 2), get_values)
240+
validate_xr_variable(ds, "inline", {"inline": 345}, UNITS_NONE, np.int32, range(1, 346), get_values)
241+
validate_xr_variable(ds, "crossline", {"crossline": 188}, UNITS_NONE, np.int32, range(1, 189), get_values)
242+
validate_xr_variable(ds, "time", {"time": 1501}, UNITS_SECOND, np.int32, range(0, 3002, 2), get_values)
230243

231244
# Validate the non-dimensional coordinate variables
232-
validate_variable(ds, "cdp_x", (345, 188), ("inline", "crossline"), np.float64, None, None)
233-
validate_variable(ds, "cdp_y", (345, 188), ("inline", "crossline"), np.float64, None, None)
245+
validate_xr_variable(ds, "cdp_x", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64, None, None)
246+
validate_xr_variable(ds, "cdp_y", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64, None, None)
234247

235248
# Validate the headers
236249
# We have a custom set of headers since we used customize_segy_specs()
237250
segy_spec = teapot_segy_spec
238251
data_type = segy_spec.trace.header.dtype
239252

240-
validate_variable(
253+
validate_xr_variable(
241254
ds,
242255
"headers",
243-
(345, 188),
244-
("inline", "crossline"),
256+
{"inline": 345, "crossline": 188},
257+
UNITS_NONE,
245258
data_type.newbyteorder("native"), # mdio saves with machine endian, spec could be different endian
246259
range(1, 346),
247260
get_inline_header_values,
248261
)
249262

250263
# Validate the trace mask
251-
validate_variable(ds, "trace_mask", (345, 188), ("inline", "crossline"), np.bool, None, None)
264+
validate_xr_variable(ds, "trace_mask", {"inline": 345, "crossline": 188}, UNITS_NONE, np.bool, None, None)
252265

253266
# validate the amplitude data
254-
validate_variable(
267+
validate_xr_variable(
255268
ds,
256269
"amplitude",
257-
(345, 188, 1501),
258-
("inline", "crossline", "time"),
270+
{"inline": 345, "crossline": 188, "time": 1501},
271+
UNITS_NONE,
259272
np.float32,
260273
None,
261274
None,

tests/integration/testing_helpers.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import numpy as np
66
import xarray as xr
7-
from numpy.typing import DTypeLike
7+
8+
from mdio.builder.schemas.v1.units import AllUnitModel
89

910

1011
def get_values(arr: xr.DataArray) -> np.ndarray:
@@ -17,35 +18,40 @@ def get_inline_header_values(dataset: xr.Dataset) -> np.ndarray:
1718
return dataset["inline"].values
1819

1920

20-
def validate_variable( # noqa PLR0913
21+
def validate_xr_variable( # noqa PLR0913
2122
dataset: xr.Dataset,
2223
name: str,
23-
shape: tuple[int, ...],
24-
dims: tuple[str, ...],
25-
data_type: DTypeLike,
24+
dims: dict[int],
25+
units: AllUnitModel,
26+
data_type: np.dtype,
2627
expected_values: range | None,
2728
actual_value_generator: Callable[[xr.DataArray], np.ndarray] | None = None,
2829
) -> None:
2930
"""Validate the properties of a variable in an Xarray dataset."""
30-
arr = dataset[name]
31-
assert shape == arr.shape
32-
assert set(dims) == set(arr.dims)
31+
v = dataset[name]
32+
assert v is not None
33+
assert v.sizes == dims
3334
if hasattr(data_type, "fields") and data_type.fields is not None:
3435
# The following assertion will fail because of differences in offsets
3536
# assert data_type == arr.dtype
3637

3738
# Compare field names
3839
expected_names = list(data_type.names)
39-
actual_names = list(arr.dtype.names)
40+
actual_names = list(v.dtype.names)
4041
assert expected_names == actual_names
4142

4243
# Compare field types
4344
expected_types = [data_type[name] for name in data_type.names]
44-
actual_types = [arr.dtype[name] for name in arr.dtype.names]
45+
actual_types = [v.dtype[name] for name in v.dtype.names]
4546
assert expected_types == actual_types
4647
else:
47-
assert data_type == arr.dtype
48+
assert data_type == v.dtype
49+
50+
if units is not None:
51+
assert v.attrs == {"unitsV1": units.model_dump(mode="json")}
52+
else:
53+
assert "unitsV1" not in v.attrs
4854

4955
if expected_values is not None and actual_value_generator is not None:
50-
actual_values = actual_value_generator(arr)
56+
actual_values = actual_value_generator(v)
5157
assert np.array_equal(expected_values, actual_values)

0 commit comments

Comments
 (0)