Skip to content

Commit 07833de

Browse files
committed
Implement Pydantic grid override definitions
1 parent 51ae74b commit 07833de

7 files changed

Lines changed: 263 additions & 46 deletions

File tree

src/mdio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mdio.converters import segy_to_mdio
1111
from mdio.optimize.access_pattern import OptimizedAccessPatternConfig
1212
from mdio.optimize.access_pattern import optimize_access_patterns
13+
from mdio.segy.geometry import GridOverrides
1314

1415
try:
1516
__version__ = metadata.version("multidimio")
@@ -19,6 +20,7 @@
1920

2021
__all__ = [
2122
"__version__",
23+
"GridOverrides",
2224
"open_mdio",
2325
"to_mdio",
2426
"mdio_to_segy",

src/mdio/converters/segy.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from mdio.ingestion.segy.validation import _validate_spec_in_template
3838
from mdio.segy import blocked_io
3939
from mdio.segy.file import get_segy_file_info
40+
from mdio.segy.geometry import GridOverrides
4041
from mdio.segy.utilities import get_grid_plan
4142

4243
if TYPE_CHECKING:
@@ -128,7 +129,7 @@ def filtered_add_coordinate( # noqa: ANN202
128129

129130
def _update_template_from_grid_overrides(
130131
template: AbstractDatasetTemplate,
131-
grid_overrides: dict[str, Any] | None,
132+
grid_overrides: GridOverrides | None,
132133
segy_dimensions: list[Dimension],
133134
full_chunk_shape: tuple[int, ...],
134135
chunk_size: tuple[int, ...],
@@ -178,30 +179,29 @@ def _update_template_from_grid_overrides(
178179

179180
# If using NonBinned override, expose non-binned dims as logical coordinates on the template instance
180181
# and patch _add_coordinates to skip adding them as 1D dimension coordinates
181-
if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
182-
non_binned_dims = tuple(grid_overrides["non_binned_dims"])
183-
if non_binned_dims:
184-
logger.debug(
185-
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
186-
non_binned_dims,
187-
)
188-
# Append any missing names; keep existing order and avoid duplicates
189-
existing = set(template.coordinate_names)
190-
to_add = tuple(n for n in non_binned_dims if n not in existing)
191-
if to_add:
192-
template._logical_coord_names = template._logical_coord_names + to_add
193-
194-
# Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates
195-
# This prevents them from being added with wrong dimensions (e.g., just "trace")
196-
# They will be added later by build_dataset with full spatial_dimension_names
197-
_patch_add_coordinates_for_non_binned(template, set(non_binned_dims))
182+
if grid_overrides is not None and grid_overrides.non_binned and grid_overrides.non_binned_dims:
183+
non_binned_dims = tuple(grid_overrides.non_binned_dims)
184+
logger.debug(
185+
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
186+
non_binned_dims,
187+
)
188+
# Append any missing names; keep existing order and avoid duplicates
189+
existing = set(template.coordinate_names)
190+
to_add = tuple(n for n in non_binned_dims if n not in existing)
191+
if to_add:
192+
template._logical_coord_names = template._logical_coord_names + to_add
193+
194+
# Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates
195+
# This prevents them from being added with wrong dimensions (e.g., just "trace")
196+
# They will be added later by build_dataset with full spatial_dimension_names
197+
_patch_add_coordinates_for_non_binned(template, set(non_binned_dims))
198198

199199

200200
def _scan_for_headers(
201201
segy_file_kwargs: SegyFileArguments,
202202
segy_file_info: SegyFileInfo,
203203
template: AbstractDatasetTemplate,
204-
grid_overrides: dict[str, Any] | None = None,
204+
grid_overrides: GridOverrides | None = None,
205205
) -> tuple[list[Dimension], SegyHeaderArray]:
206206
"""Extract trace dimensions and index headers from the SEG-Y file.
207207
@@ -346,13 +346,34 @@ def determine_target_size(var_type: str) -> int:
346346
ds.variables[index].metadata.chunk_grid = chunk_grid
347347

348348

349+
def _coerce_grid_overrides(
350+
grid_overrides: GridOverrides | dict[str, Any] | None,
351+
) -> GridOverrides | None:
352+
"""Normalize public ``grid_overrides`` input into a :class:`GridOverrides` model.
353+
354+
The internal ingestion pipeline only accepts the typed model. A legacy ``dict`` is
355+
converted via :meth:`GridOverrides.from_legacy_dict` and a deprecation message is logged.
356+
"""
357+
if grid_overrides is None:
358+
return None
359+
360+
if isinstance(grid_overrides, GridOverrides):
361+
return grid_overrides
362+
363+
logger.warning(
364+
"Passing `grid_overrides` as a dict is deprecated and will be removed in a "
365+
"future release; pass a `mdio.GridOverrides` instance instead."
366+
)
367+
return GridOverrides.model_validate(grid_overrides)
368+
369+
349370
def segy_to_mdio( # noqa PLR0913
350371
segy_spec: SegySpec,
351372
mdio_template: AbstractDatasetTemplate,
352373
input_path: UPath | Path | str,
353374
output_path: UPath | Path | str,
354375
overwrite: bool = False,
355-
grid_overrides: dict[str, Any] | None = None,
376+
grid_overrides: GridOverrides | dict[str, Any] | None = None,
356377
segy_header_overrides: SegyHeaderOverrides | None = None,
357378
) -> None:
358379
"""A function that converts a SEG-Y file to an MDIO v1 file.
@@ -365,12 +386,15 @@ def segy_to_mdio( # noqa PLR0913
365386
input_path: The universal path of the input SEG-Y file.
366387
output_path: The universal path for the output MDIO v1 file.
367388
overwrite: Whether to overwrite the output file if it already exists. Defaults to False.
368-
grid_overrides: Option to add grid overrides.
389+
grid_overrides: Option to add grid overrides. Prefer a :class:`mdio.GridOverrides`
390+
instance; ``dict`` is still accepted but emits a :class:`DeprecationWarning`.
369391
segy_header_overrides: Option to override specific SEG-Y headers during ingestion.
370392
371393
Raises:
372394
FileExistsError: If the output location already exists and overwrite is False.
373395
"""
396+
typed_grid_overrides = _coerce_grid_overrides(grid_overrides)
397+
374398
settings = MDIOSettings()
375399

376400
_validate_spec_in_template(segy_spec, mdio_template)
@@ -395,7 +419,7 @@ def segy_to_mdio( # noqa PLR0913
395419
segy_file_kwargs,
396420
segy_file_info,
397421
template=mdio_template,
398-
grid_overrides=grid_overrides,
422+
grid_overrides=typed_grid_overrides,
399423
)
400424
grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers)
401425

@@ -417,7 +441,7 @@ def segy_to_mdio( # noqa PLR0913
417441
mdio_template = _update_template_units(mdio_template, spatial_unit)
418442
mdio_ds: Dataset = mdio_template.build_dataset(name=mdio_template.name, sizes=grid.shape, header_dtype=header_dtype)
419443

420-
_add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides)
444+
_add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=typed_grid_overrides)
421445

422446
# Dynamically chunk the variables based on their type
423447
_chunk_variable(ds=mdio_ds, target_variable_name="trace_mask") # trace_mask is a Variable and not a Coordinate

src/mdio/ingestion/metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from __future__ import annotations
44

55
from typing import TYPE_CHECKING
6-
from typing import Any
76

87
if TYPE_CHECKING:
98
from mdio.builder.schemas import Dataset
9+
from mdio.segy.geometry import GridOverrides
1010

1111

12-
def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, Any] | None) -> None:
12+
def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: GridOverrides | None) -> None:
1313
"""Add grid override to Dataset metadata if needed."""
1414
if dataset.metadata.attributes is None:
1515
dataset.metadata.attributes = {}
1616

1717
if grid_overrides is not None:
18-
dataset.metadata.attributes["gridOverrides"] = grid_overrides
18+
dataset.metadata.attributes["gridOverrides"] = grid_overrides.to_legacy_dict()

src/mdio/segy/geometry.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
from abc import ABC
77
from abc import abstractmethod
88
from typing import TYPE_CHECKING
9+
from typing import Any
910

1011
import numpy as np
1112
from numpy.lib import recfunctions as rfn
13+
from pydantic import BaseModel
14+
from pydantic import ConfigDict
15+
from pydantic import Field
1216

1317
from mdio.ingestion.segy.header_analysis import ShotGunGeometryType
1418
from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType
@@ -31,6 +35,61 @@
3135
logger = logging.getLogger(__name__)
3236

3337

38+
class GridOverrides(BaseModel):
39+
"""Type-safe configuration for grid override operations during SEG-Y ingestion."""
40+
41+
model_config = ConfigDict(extra="forbid", validate_by_name=True)
42+
43+
auto_channel_wrap: bool = Field(
44+
default=False,
45+
alias="AutoChannelWrap",
46+
description="Streamer: auto-detect channel-wrap geometry (Type A vs B).",
47+
)
48+
auto_shot_wrap: bool = Field(
49+
default=False,
50+
alias="AutoShotWrap",
51+
description="Streamer: derive dense shot_index from interleaved shot_point values.",
52+
)
53+
calculate_shot_index: bool = Field(
54+
default=False,
55+
alias="CalculateShotIndex",
56+
description="OBN: derive dense shot_index from sparse shot_point values per shot_line.",
57+
)
58+
non_binned: bool = Field(
59+
default=False,
60+
alias="NonBinned",
61+
description="Collapse selected dims into a single trace dimension without spatial binning.",
62+
)
63+
has_duplicates: bool = Field(
64+
default=False,
65+
alias="HasDuplicates",
66+
description="Add a trace dimension (chunksize 1) to disambiguate duplicate trace indices.",
67+
)
68+
chunksize: int | None = Field(
69+
default=None,
70+
gt=0,
71+
description="Chunk size for the trace dimension when `non_binned` is True.",
72+
)
73+
non_binned_dims: list[str] | None = Field(
74+
default=None,
75+
description="Dimension names to collapse into the trace dimension when `non_binned` is True.",
76+
)
77+
78+
def __bool__(self) -> bool:
79+
"""Return True if any override flag is enabled."""
80+
return (
81+
self.auto_channel_wrap
82+
or self.auto_shot_wrap
83+
or self.calculate_shot_index
84+
or self.non_binned
85+
or self.has_duplicates
86+
)
87+
88+
def to_legacy_dict(self) -> dict[str, Any]:
89+
"""Dump to the legacy ``CamelCase`` dict shape consumed by :class:`GridOverrider`."""
90+
return self.model_dump(by_alias=True, exclude_defaults=True)
91+
92+
3493
class GridOverrideCommand(ABC):
3594
"""Abstract base class for grid override commands."""
3695

src/mdio/segy/utilities.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import itertools
66
import logging
77
from typing import TYPE_CHECKING
8-
from typing import Any
98

109
import numpy as np
1110
from dask.array.core import normalize_chunks
@@ -24,6 +23,7 @@
2423
from mdio.builder.templates.base import AbstractDatasetTemplate
2524
from mdio.segy.file import SegyFileArguments
2625
from mdio.segy.file import SegyFileInfo
26+
from mdio.segy.geometry import GridOverrides
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -34,7 +34,7 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915
3434
chunksize: tuple[int, ...] | None,
3535
template: AbstractDatasetTemplate,
3636
return_headers: bool = False,
37-
grid_overrides: dict[str, Any] | None = None,
37+
grid_overrides: GridOverrides | None = None,
3838
) -> tuple[list[Dimension], tuple[int, ...]] | tuple[list[Dimension], tuple[int, ...], HeaderArray]:
3939
"""Infer dimension ranges, and increments.
4040
@@ -50,17 +50,14 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915
5050
chunksize: Chunk sizes to be used in grid plan.
5151
template: MDIO template where coordinate names and domain will be taken.
5252
return_headers: Option to return parsed headers with `Dimension` objects. Default is False.
53-
grid_overrides: Option to add grid overrides. See main documentation.
53+
grid_overrides: Typed grid override configuration, or ``None`` for no overrides.
5454
5555
Returns:
5656
All index dimensions and chunksize or dimensions and chunksize together with header values.
5757
5858
Raises:
5959
ValueError: If computed fields are not found after grid overrides.
6060
"""
61-
if grid_overrides is None:
62-
grid_overrides = {}
63-
6461
# Keep only dimension and non-dimension coordinates excluding the vertical axis
6562
horizontal_dimensions = template.spatial_dimension_names
6663
horizontal_coordinates = horizontal_dimensions + template.coordinate_names
@@ -72,8 +69,8 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915
7269
horizontal_coordinates = tuple(c for c in horizontal_coordinates if c not in computed_fields)
7370

7471
# Ensure non_binned_dims are included in the headers to parse, even if not in template
75-
if grid_overrides and "non_binned_dims" in grid_overrides:
76-
for dim in grid_overrides["non_binned_dims"]:
72+
if grid_overrides is not None and grid_overrides.non_binned_dims:
73+
for dim in grid_overrides.non_binned_dims:
7774
if dim not in horizontal_coordinates:
7875
horizontal_coordinates = horizontal_coordinates + (dim,)
7976

@@ -94,20 +91,24 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915
9491
subset=tuple(c for c in horizontal_coordinates if c not in fields_to_skip),
9592
)
9693

97-
# Handle grid overrides.
94+
# The legacy GridOverrider still consumes the dict shape; dump only at this boundary.
95+
# Future PR will replace GridOverrider.
9896
override_handler = GridOverrider()
9997
headers_subset, horizontal_coordinates, chunksize = override_handler.run(
10098
headers_subset,
10199
horizontal_coordinates,
102100
chunksize=chunksize,
103-
grid_overrides=grid_overrides,
101+
grid_overrides=grid_overrides.to_legacy_dict() if grid_overrides is not None else {},
104102
template=template,
105103
)
106104

107105
# After grid overrides, determine final spatial dimensions and their chunk sizes
108-
non_binned_dims = set()
109-
if "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
110-
non_binned_dims = set(grid_overrides["non_binned_dims"])
106+
non_binned_active = grid_overrides is not None and grid_overrides.non_binned
107+
non_binned_dims: set[str] = (
108+
set(grid_overrides.non_binned_dims)
109+
if non_binned_active and grid_overrides is not None and grid_overrides.non_binned_dims
110+
else set()
111+
)
111112

112113
# Create mapping from dimension name to original chunk size for easy lookup
113114
original_spatial_dims = list(template.spatial_dimension_names)
@@ -121,8 +122,11 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915
121122
if name in non_binned_dims:
122123
continue # Skip dimensions that became coordinates
123124
if name == "trace":
124-
# Special handling for trace dimension
125-
chunk_val = int(grid_overrides.get("chunksize", 1)) if "NonBinned" in grid_overrides else 1
125+
chunk_val = (
126+
int(grid_overrides.chunksize)
127+
if non_binned_active and grid_overrides is not None and grid_overrides.chunksize is not None
128+
else 1
129+
)
126130
final_spatial_dims.append(name)
127131
final_spatial_chunks.append(chunk_val)
128132
elif name in dim_to_chunk:

tests/unit/ingestion/test_metadata.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from types import SimpleNamespace
66

77
from mdio.ingestion.metadata import _add_grid_override_to_metadata
8+
from mdio.segy.geometry import GridOverrides
89

910

1011
def _make_dataset(attributes: dict | None) -> SimpleNamespace:
@@ -22,18 +23,23 @@ def test_initializes_attributes_dict_when_none(self) -> None:
2223
assert dataset.metadata.attributes == {}
2324

2425
def test_adds_grid_overrides_when_provided(self) -> None:
25-
"""Grid overrides should land under the ``gridOverrides`` key."""
26+
"""Active grid overrides should serialize under the ``gridOverrides`` key."""
2627
dataset = _make_dataset(attributes=None)
27-
overrides = {"HasDuplicates": True, "chunksize": 4}
28+
overrides = GridOverrides(has_duplicates=True, chunksize=4)
2829
_add_grid_override_to_metadata(dataset, grid_overrides=overrides)
29-
assert dataset.metadata.attributes == {"gridOverrides": overrides}
30+
assert dataset.metadata.attributes == {
31+
"gridOverrides": {"HasDuplicates": True, "chunksize": 4},
32+
}
3033

3134
def test_preserves_existing_attributes(self) -> None:
3235
"""Existing attribute keys should be preserved when adding overrides."""
3336
dataset = _make_dataset(attributes={"existing": "value"})
34-
overrides = {"NonBinned": True}
37+
overrides = GridOverrides(non_binned=True)
3538
_add_grid_override_to_metadata(dataset, grid_overrides=overrides)
36-
assert dataset.metadata.attributes == {"existing": "value", "gridOverrides": overrides}
39+
assert dataset.metadata.attributes == {
40+
"existing": "value",
41+
"gridOverrides": {"NonBinned": True},
42+
}
3743

3844
def test_no_overrides_leaves_attributes_untouched(self) -> None:
3945
"""Passing ``None`` overrides must not introduce a ``gridOverrides`` key."""

0 commit comments

Comments
 (0)