Skip to content

Commit 00d3d7d

Browse files
authored
Implement Pydantic grid override definitions (#820)
* Implement Pydantic grid override definitions * Update docs
1 parent 51ae74b commit 00d3d7d

10 files changed

Lines changed: 299 additions & 48 deletions

File tree

docs/api_reference.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ and
3232
:members:
3333
```
3434

35+
### Grid Overrides
36+
37+
```{eval-rst}
38+
.. autopydantic_model:: mdio.GridOverrides
39+
```
40+
3541
## Core Functionality
3642

3743
### Dimensions

docs/guides/grid_overrides.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,30 @@ Grid overrides are transformations applied during SEG-Y import that modify how t
1010

1111
When importing SEG-Y data, MDIO maps trace header fields to dataset dimensions. However, real-world seismic data often has complexities that require additional processing. Grid overrides address these issues by transforming header values before indexing.
1212

13+
## Configuring grid overrides
14+
15+
Grid overrides are passed to {func}`mdio.segy_to_mdio` via the `grid_overrides` argument as an
16+
{class}`mdio.GridOverrides` instance:
17+
18+
```python
19+
from mdio import GridOverrides
20+
from mdio import segy_to_mdio
21+
22+
segy_to_mdio(
23+
...,
24+
grid_overrides=GridOverrides(calculate_shot_index=True),
25+
)
26+
```
27+
28+
Both modern `snake_case` field names and the legacy `CamelCase` aliases are accepted, so
29+
`GridOverrides(CalculateShotIndex=True)` is equivalent to the example above. Unknown keys
30+
are rejected at construction with a `pydantic.ValidationError`.
31+
32+
```{deprecated} 1.2
33+
Passing `grid_overrides` as a `dict` still works but logs a deprecation warning; `dict` support will be
34+
removed in a future release. Switch to `mdio.GridOverrides`.
35+
```
36+
1337
## CalculateShotIndex
1438

1539
Calculates a dense `shot_index` dimension from sparse or interleaved `shot_point` values. Required for the `ObnReceiverGathers3D` template.
@@ -37,12 +61,15 @@ The override detects the geometry type and only applies the transformation when
3761
**Usage:**
3862

3963
```python
64+
from mdio import GridOverrides
65+
from mdio import segy_to_mdio
66+
4067
segy_to_mdio(
4168
input_path="obn_data.sgy",
4269
output_path="obn_data.mdio",
4370
segy_spec=obn_spec,
4471
mdio_template=get_template("ObnReceiverGathers3D"),
45-
grid_overrides={"CalculateShotIndex": True},
72+
grid_overrides=GridOverrides(calculate_shot_index=True),
4673
)
4774
```
4875

docs/guides/obn_data_import.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ A warning is logged when component is synthesized:
6666
from segy.schema import HeaderField
6767
from segy.standards import get_segy_standard
6868

69+
from mdio import GridOverrides
6970
from mdio import segy_to_mdio
7071
from mdio.builder.template_registry import get_template
7172

@@ -91,7 +92,7 @@ segy_to_mdio(
9192
output_path="obn_data.mdio",
9293
segy_spec=obn_spec,
9394
mdio_template=get_template("ObnReceiverGathers3D"),
94-
grid_overrides={"CalculateShotIndex": True},
95+
grid_overrides=GridOverrides(calculate_shot_index=True),
9596
overwrite=True,
9697
)
9798
```

src/mdio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mdio.converters import segy_to_mdio
1111
from mdio.optimize.access_pattern import OptimizedAccessPatternConfig
1212
from mdio.optimize.access_pattern import optimize_access_patterns
13+
from mdio.segy.geometry import GridOverrides
1314

1415
try:
1516
__version__ = metadata.version("multidimio")
@@ -19,6 +20,7 @@
1920

2021
__all__ = [
2122
"__version__",
23+
"GridOverrides",
2224
"open_mdio",
2325
"to_mdio",
2426
"mdio_to_segy",

src/mdio/converters/segy.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from mdio.ingestion.segy.validation import _validate_spec_in_template
3838
from mdio.segy import blocked_io
3939
from mdio.segy.file import get_segy_file_info
40+
from mdio.segy.geometry import GridOverrides
4041
from mdio.segy.utilities import get_grid_plan
4142

4243
if TYPE_CHECKING:
@@ -128,7 +129,7 @@ def filtered_add_coordinate( # noqa: ANN202
128129

129130
def _update_template_from_grid_overrides(
130131
template: AbstractDatasetTemplate,
131-
grid_overrides: dict[str, Any] | None,
132+
grid_overrides: GridOverrides | None,
132133
segy_dimensions: list[Dimension],
133134
full_chunk_shape: tuple[int, ...],
134135
chunk_size: tuple[int, ...],
@@ -178,30 +179,29 @@ def _update_template_from_grid_overrides(
178179

179180
# If using NonBinned override, expose non-binned dims as logical coordinates on the template instance
180181
# and patch _add_coordinates to skip adding them as 1D dimension coordinates
181-
if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
182-
non_binned_dims = tuple(grid_overrides["non_binned_dims"])
183-
if non_binned_dims:
184-
logger.debug(
185-
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
186-
non_binned_dims,
187-
)
188-
# Append any missing names; keep existing order and avoid duplicates
189-
existing = set(template.coordinate_names)
190-
to_add = tuple(n for n in non_binned_dims if n not in existing)
191-
if to_add:
192-
template._logical_coord_names = template._logical_coord_names + to_add
193-
194-
# Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates
195-
# This prevents them from being added with wrong dimensions (e.g., just "trace")
196-
# They will be added later by build_dataset with full spatial_dimension_names
197-
_patch_add_coordinates_for_non_binned(template, set(non_binned_dims))
182+
if grid_overrides is not None and grid_overrides.non_binned and grid_overrides.non_binned_dims:
183+
non_binned_dims = tuple(grid_overrides.non_binned_dims)
184+
logger.debug(
185+
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
186+
non_binned_dims,
187+
)
188+
# Append any missing names; keep existing order and avoid duplicates
189+
existing = set(template.coordinate_names)
190+
to_add = tuple(n for n in non_binned_dims if n not in existing)
191+
if to_add:
192+
template._logical_coord_names = template._logical_coord_names + to_add
193+
194+
# Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates
195+
# This prevents them from being added with wrong dimensions (e.g., just "trace")
196+
# They will be added later by build_dataset with full spatial_dimension_names
197+
_patch_add_coordinates_for_non_binned(template, set(non_binned_dims))
198198

199199

200200
def _scan_for_headers(
201201
segy_file_kwargs: SegyFileArguments,
202202
segy_file_info: SegyFileInfo,
203203
template: AbstractDatasetTemplate,
204-
grid_overrides: dict[str, Any] | None = None,
204+
grid_overrides: GridOverrides | None = None,
205205
) -> tuple[list[Dimension], SegyHeaderArray]:
206206
"""Extract trace dimensions and index headers from the SEG-Y file.
207207
@@ -346,13 +346,34 @@ def determine_target_size(var_type: str) -> int:
346346
ds.variables[index].metadata.chunk_grid = chunk_grid
347347

348348

349+
def _coerce_grid_overrides(
350+
grid_overrides: GridOverrides | dict[str, Any] | None,
351+
) -> GridOverrides | None:
352+
"""Normalize public ``grid_overrides`` input into a :class:`GridOverrides` model.
353+
354+
The internal ingestion pipeline only accepts the typed model. A legacy ``dict`` is
355+
converted via :meth:`GridOverrides.model_validate` and a deprecation message is logged.
356+
"""
357+
if grid_overrides is None:
358+
return None
359+
360+
if isinstance(grid_overrides, GridOverrides):
361+
return grid_overrides
362+
363+
logger.warning(
364+
"Passing `grid_overrides` as a dict is deprecated and will be removed in a "
365+
"future release; pass a `mdio.GridOverrides` instance instead."
366+
)
367+
return GridOverrides.model_validate(grid_overrides)
368+
369+
349370
def segy_to_mdio( # noqa PLR0913
350371
segy_spec: SegySpec,
351372
mdio_template: AbstractDatasetTemplate,
352373
input_path: UPath | Path | str,
353374
output_path: UPath | Path | str,
354375
overwrite: bool = False,
355-
grid_overrides: dict[str, Any] | None = None,
376+
grid_overrides: GridOverrides | dict[str, Any] | None = None,
356377
segy_header_overrides: SegyHeaderOverrides | None = None,
357378
) -> None:
358379
"""A function that converts a SEG-Y file to an MDIO v1 file.
@@ -365,12 +386,15 @@ def segy_to_mdio( # noqa PLR0913
365386
input_path: The universal path of the input SEG-Y file.
366387
output_path: The universal path for the output MDIO v1 file.
367388
overwrite: Whether to overwrite the output file if it already exists. Defaults to False.
368-
grid_overrides: Option to add grid overrides.
389+
grid_overrides: Option to add grid overrides. Prefer a :class:`mdio.GridOverrides`
390+
instance; ``dict`` is still accepted but logs a deprecation warning.
369391
segy_header_overrides: Option to override specific SEG-Y headers during ingestion.
370392
371393
Raises:
372394
FileExistsError: If the output location already exists and overwrite is False.
373395
"""
396+
typed_grid_overrides = _coerce_grid_overrides(grid_overrides)
397+
374398
settings = MDIOSettings()
375399

376400
_validate_spec_in_template(segy_spec, mdio_template)
@@ -395,7 +419,7 @@ def segy_to_mdio( # noqa PLR0913
395419
segy_file_kwargs,
396420
segy_file_info,
397421
template=mdio_template,
398-
grid_overrides=grid_overrides,
422+
grid_overrides=typed_grid_overrides,
399423
)
400424
grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers)
401425

@@ -417,7 +441,7 @@ def segy_to_mdio( # noqa PLR0913
417441
mdio_template = _update_template_units(mdio_template, spatial_unit)
418442
mdio_ds: Dataset = mdio_template.build_dataset(name=mdio_template.name, sizes=grid.shape, header_dtype=header_dtype)
419443

420-
_add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides)
444+
_add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=typed_grid_overrides)
421445

422446
# Dynamically chunk the variables based on their type
423447
_chunk_variable(ds=mdio_ds, target_variable_name="trace_mask") # trace_mask is a Variable and not a Coordinate

src/mdio/ingestion/metadata.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from __future__ import annotations
44

55
from typing import TYPE_CHECKING
6-
from typing import Any
76

87
if TYPE_CHECKING:
98
from mdio.builder.schemas import Dataset
9+
from mdio.segy.geometry import GridOverrides
1010

1111

12-
def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, Any] | None) -> None:
12+
def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: GridOverrides | None) -> None:
1313
"""Add grid override to Dataset metadata if needed."""
1414
if dataset.metadata.attributes is None:
1515
dataset.metadata.attributes = {}
1616

1717
if grid_overrides is not None:
18-
dataset.metadata.attributes["gridOverrides"] = grid_overrides
18+
dataset.metadata.attributes["gridOverrides"] = grid_overrides.to_legacy_dict()

src/mdio/segy/geometry.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
from abc import ABC
77
from abc import abstractmethod
88
from typing import TYPE_CHECKING
9+
from typing import Any
910

1011
import numpy as np
1112
from numpy.lib import recfunctions as rfn
13+
from pydantic import BaseModel
14+
from pydantic import ConfigDict
15+
from pydantic import Field
1216

1317
from mdio.ingestion.segy.header_analysis import ShotGunGeometryType
1418
from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType
@@ -31,6 +35,61 @@
3135
logger = logging.getLogger(__name__)
3236

3337

38+
class GridOverrides(BaseModel):
39+
"""Type-safe configuration for grid override operations during SEG-Y ingestion."""
40+
41+
model_config = ConfigDict(extra="forbid", validate_by_name=True)
42+
43+
auto_channel_wrap: bool = Field(
44+
default=False,
45+
alias="AutoChannelWrap",
46+
description="Streamer: auto-detect channel-wrap geometry (Type A vs B).",
47+
)
48+
auto_shot_wrap: bool = Field(
49+
default=False,
50+
alias="AutoShotWrap",
51+
description="Streamer: derive dense shot_index from interleaved shot_point values.",
52+
)
53+
calculate_shot_index: bool = Field(
54+
default=False,
55+
alias="CalculateShotIndex",
56+
description="OBN: derive dense shot_index from sparse shot_point values per shot_line.",
57+
)
58+
non_binned: bool = Field(
59+
default=False,
60+
alias="NonBinned",
61+
description="Collapse selected dims into a single trace dimension without spatial binning.",
62+
)
63+
has_duplicates: bool = Field(
64+
default=False,
65+
alias="HasDuplicates",
66+
description="Add a trace dimension (chunksize 1) to disambiguate duplicate trace indices.",
67+
)
68+
chunksize: int | None = Field(
69+
default=None,
70+
gt=0,
71+
description="Chunk size for the trace dimension when `non_binned` is True.",
72+
)
73+
non_binned_dims: list[str] | None = Field(
74+
default=None,
75+
description="Dimension names to collapse into the trace dimension when `non_binned` is True.",
76+
)
77+
78+
def __bool__(self) -> bool:
79+
"""Return True if any override flag is enabled."""
80+
return (
81+
self.auto_channel_wrap
82+
or self.auto_shot_wrap
83+
or self.calculate_shot_index
84+
or self.non_binned
85+
or self.has_duplicates
86+
)
87+
88+
def to_legacy_dict(self) -> dict[str, Any]:
89+
"""Dump to the legacy ``CamelCase`` dict shape consumed by :class:`GridOverrider`."""
90+
return self.model_dump(by_alias=True, exclude_defaults=True)
91+
92+
3493
class GridOverrideCommand(ABC):
3594
"""Abstract base class for grid override commands."""
3695

0 commit comments

Comments
 (0)