Skip to content

Commit 57024d6

Browse files
authored
Merge pull request #35 from d-v-b/chore/types
chore: types
2 parents 23db304 + 15c9227 commit 57024d6

2 files changed

Lines changed: 172 additions & 66 deletions

File tree

src/eopf_geozarr/conversion/geozarr.py

Lines changed: 77 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import os
1919
import shutil
2020
import time
21-
from typing import Any, Dict, Hashable, List, Optional, Tuple
21+
from collections.abc import Hashable, Iterable, Mapping
22+
from typing import Any, Dict, List, Tuple
2223

2324
import numpy as np
2425
import xarray as xr
@@ -30,21 +31,32 @@
3031
from zarr.storage import StoreLike
3132
from zarr.storage._common import make_store_path
3233

34+
from eopf_geozarr.types import (
35+
OverviewLevelJSON,
36+
StandardLatCoordAttrsJSON,
37+
StandardLonCoordAttrsJSON,
38+
StandardXCoordAttrsJSON,
39+
StandardYCoordAttrsJSON,
40+
TileMatrixJSON,
41+
TileMatrixLimitJSON,
42+
TileMatrixSetJSON,
43+
XarrayEncodingJSON,
44+
)
45+
3346
from . import fs_utils, utils
3447
from .sentinel1_reprojection import reproject_sentinel1_with_gcps
3548

3649

3750
def create_geozarr_dataset(
3851
dt_input: xr.DataTree,
39-
*,
40-
groups: List[str],
52+
groups: Iterable[str],
4153
output_path: str,
4254
spatial_chunk: int = 4096,
4355
min_dimension: int = 256,
4456
tile_width: int = 256,
4557
max_retries: int = 3,
46-
crs_groups: Optional[List[str]] = None,
47-
gcp_group: Optional[str] = None,
58+
crs_groups: list[str] | None = None,
59+
gcp_group: str | None = None,
4860
) -> xr.DataTree:
4961
"""
5062
Create a GeoZarr-spec 0.4 compliant dataset from EOPF data.
@@ -135,7 +147,7 @@ def create_geozarr_dataset(
135147

136148

137149
def setup_datatree_metadata_geozarr_spec_compliant(
138-
dt: xr.DataTree, groups: List[str], gcp_group: Optional[str] = None
150+
dt: xr.DataTree, groups: Iterable[str], gcp_group: str | None = None
139151
) -> Dict[str, xr.Dataset]:
140152
"""
141153
Set up GeoZarr-spec compliant CF standard names and CRS information.
@@ -152,7 +164,7 @@ def setup_datatree_metadata_geozarr_spec_compliant(
152164
dict[str, xr.Dataset]
153165
Dictionary of datasets with GeoZarr compliance applied
154166
"""
155-
geozarr_groups = {}
167+
geozarr_groups: dict[str, xr.Dataset] = {}
156168
grid_mapping_var_name = "spatial_ref"
157169

158170
for key in groups:
@@ -209,15 +221,15 @@ def setup_datatree_metadata_geozarr_spec_compliant(
209221

210222
def iterative_copy(
211223
dt_input: xr.DataTree,
212-
geozarr_groups: Dict[str, xr.Dataset],
224+
geozarr_groups: dict[str, xr.Dataset],
213225
output_path: str,
214226
compressor: Any,
215227
spatial_chunk: int = 4096,
216228
min_dimension: int = 256,
217229
tile_width: int = 256,
218230
max_retries: int = 3,
219-
crs_groups: Optional[List[str]] = None,
220-
gcp_group: Optional[str] = None,
231+
crs_groups: list[str] | None = None,
232+
gcp_group: str | None = None,
221233
) -> xr.DataTree:
222234
"""
223235
Iteratively copy groups from original DataTree to GeoZarr DataTree.
@@ -332,7 +344,7 @@ def iterative_copy(
332344

333345

334346
def prepare_dataset_with_crs_info(
335-
ds: xr.Dataset, reference_crs: Optional[str] = None
347+
ds: xr.Dataset, reference_crs: str | None = None
336348
) -> xr.Dataset:
337349
"""
338350
Prepare a dataset with CRS information without writing it to disk.
@@ -394,7 +406,7 @@ def write_geozarr_group(
394406
max_retries: int = 3,
395407
min_dimension: int = 256,
396408
tile_width: int = 256,
397-
gcp_group: Optional[str] = None,
409+
gcp_group: str | None = None,
398410
) -> xr.DataTree:
399411
"""
400412
Write a group to a GeoZarr dataset with multiscales support.
@@ -504,7 +516,7 @@ def create_geozarr_compliant_multiscales(
504516
min_dimension: int = 256,
505517
tile_width: int = 256,
506518
spatial_chunk: int = 4096,
507-
ds_gcp: Optional[xr.Dataset] = None,
519+
ds_gcp: xr.Dataset | None = None,
508520
) -> Dict[str, Any]:
509521
"""
510522
Create GeoZarr-spec compliant multiscales following the specification exactly.
@@ -736,7 +748,7 @@ def calculate_overview_levels(
736748
native_height: int,
737749
min_dimension: int = 256,
738750
tile_width: int = 256,
739-
) -> List[Dict[str, Any]]:
751+
) -> list[OverviewLevelJSON]:
740752
"""
741753
Calculate overview levels following COG /2 downsampling logic.
742754
@@ -756,7 +768,7 @@ def calculate_overview_levels(
756768
list
757769
List of overview level dictionaries
758770
"""
759-
overview_levels = []
771+
overview_levels: list[OverviewLevelJSON] = []
760772
level = 0
761773
current_width = native_width
762774
current_height = native_height
@@ -786,10 +798,10 @@ def calculate_overview_levels(
786798

787799
def create_native_crs_tile_matrix_set(
788800
native_crs: Any,
789-
native_bounds: Tuple[float, float, float, float],
790-
overview_levels: List[Dict[str, Any]],
791-
group_prefix: Optional[str] = "",
792-
) -> Dict[str, Any]:
801+
native_bounds: tuple[float, float, float, float],
802+
overview_levels: Iterable[OverviewLevelJSON],
803+
group_prefix: str | None = "",
804+
) -> TileMatrixSetJSON:
793805
"""
794806
Create a custom Tile Matrix Set for the native CRS following GeoZarr spec.
795807
@@ -810,7 +822,7 @@ def create_native_crs_tile_matrix_set(
810822
Tile Matrix Set definition following OGC standard
811823
"""
812824
left, bottom, right, top = native_bounds
813-
tile_matrices = []
825+
tile_matrices: list[TileMatrixJSON] = []
814826

815827
for overview in overview_levels:
816828
level = overview["level"]
@@ -872,7 +884,7 @@ def create_overview_dataset_all_vars(
872884
native_crs: Any,
873885
native_bounds: Tuple[float, float, float, float],
874886
data_vars: List[Hashable],
875-
ds_gcp: Optional[xr.Dataset] = None,
887+
ds_gcp: xr.Dataset | None = None,
876888
) -> xr.Dataset:
877889
"""
878890
Create an overview dataset containing all variables for a specific level.
@@ -914,26 +926,21 @@ def create_overview_dataset_all_vars(
914926

915927
# Check if we're dealing with geographic coordinates (EPSG:4326)
916928
if native_crs and native_crs.to_epsg() == 4326:
917-
x_attrs = {
918-
"_ARRAY_DIMENSIONS": ["x"],
919-
"standard_name": "longitude",
920-
"units": "degrees_east",
921-
"long_name": "longitude",
922-
}
923-
y_attrs = {
924-
"_ARRAY_DIMENSIONS": ["y"],
925-
"standard_name": "latitude",
926-
"units": "degrees_north",
927-
"long_name": "latitude",
929+
lon_attrs = _get_lon_coord_attrs()
930+
lat_attrs = _get_lat_coord_attrs()
931+
overview_coords = {
932+
"x": (["x"], x_coords, lon_attrs),
933+
"y": (["y"], y_coords, lat_attrs),
928934
}
935+
929936
else:
930937
x_attrs = _get_x_coord_attrs()
931938
y_attrs = _get_y_coord_attrs()
932939

933-
overview_coords = {
934-
"x": (["x"], x_coords, x_attrs),
935-
"y": (["y"], y_coords, y_attrs),
936-
}
940+
overview_coords = {
941+
"x": (["x"], x_coords, x_attrs),
942+
"y": (["y"], y_coords, y_attrs),
943+
}
937944

938945
# Determine standard name based on whether this is Sentinel-1 data
939946
# TODO: use a better way to determine this than just checking for ds_gcp
@@ -999,13 +1006,13 @@ def create_overview_dataset_all_vars(
9991006

10001007
def write_dataset_band_by_band_with_validation(
10011008
ds: xr.Dataset,
1002-
existing_dataset: Optional[xr.Dataset],
1009+
existing_dataset: xr.Dataset | None,
10031010
output_path: str,
1004-
encoding: Dict[str, Any],
1011+
encoding: dict[Hashable, XarrayEncodingJSON],
10051012
max_retries: int,
10061013
group_name: str,
10071014
force_overwrite: bool = False,
1008-
) -> Tuple[bool, xr.Dataset]:
1015+
) -> tuple[bool, xr.Dataset]:
10091016
"""
10101017
Write dataset band by band with individual band validation.
10111018
@@ -1185,8 +1192,8 @@ def write_dataset_band_by_band_with_validation(
11851192

11861193
def consolidate_metadata(
11871194
store: StoreLike,
1188-
path: Optional[str] = None,
1189-
zarr_format: Optional[zarr.core.common.ZarrFormat] = None,
1195+
path: str | None = None,
1196+
zarr_format: zarr.core.common.ZarrFormat | None = None,
11901197
) -> zarr.Group:
11911198
"""
11921199
Consolidate metadata of all nodes in a hierarchy.
@@ -1212,8 +1219,8 @@ def consolidate_metadata(
12121219

12131220
async def async_consolidate_metadata(
12141221
store: StoreLike,
1215-
path: Optional[str] = None,
1216-
zarr_format: Optional[zarr.core.common.ZarrFormat] = None,
1222+
path: str | None = None,
1223+
zarr_format: zarr.core.common.ZarrFormat | None = None,
12171224
) -> zarr.core.group.AsyncGroup:
12181225
"""
12191226
Consolidate metadata of all nodes in a hierarchy asynchronously.
@@ -1382,9 +1389,9 @@ def _add_geotransform(ds: xr.Dataset, grid_mapping_var: str) -> None:
13821389
ds[grid_mapping_var].attrs["GeoTransform"] = transform_str
13831390

13841391

1385-
def _find_reference_crs(geozarr_groups: Dict[str, xr.Dataset]) -> Optional[str]:
1392+
def _find_reference_crs(geozarr_groups: Mapping[str, xr.Dataset]) -> str | None:
13861393
"""Find the reference CRS in the geozarr groups."""
1387-
for key, group in geozarr_groups.items():
1394+
for group in geozarr_groups.values():
13881395
if group.rio.crs:
13891396
crs_string: str = group.rio.crs.to_string()
13901397
return crs_string
@@ -1393,9 +1400,10 @@ def _find_reference_crs(geozarr_groups: Dict[str, xr.Dataset]) -> Optional[str]:
13931400

13941401
def _create_encoding(
13951402
ds: xr.Dataset, compressor: Any, spatial_chunk: int
1396-
) -> Dict[str, Any]:
1403+
) -> dict[Hashable, XarrayEncodingJSON]:
13971404
"""Create encoding for dataset variables."""
1398-
encoding: Dict[str, Any] = {}
1405+
encoding: dict[Hashable, XarrayEncodingJSON] = {}
1406+
chunking: tuple[int, ...]
13991407
for var in ds.data_vars:
14001408
if hasattr(ds[var].data, "chunks"):
14011409
current_chunks = ds[var].chunks
@@ -1435,9 +1443,9 @@ def _create_encoding(
14351443

14361444
def _create_geozarr_encoding(
14371445
ds: xr.Dataset, compressor: Any, spatial_chunk: int
1438-
) -> Dict[str, Any]:
1446+
) -> dict[Hashable, XarrayEncodingJSON]:
14391447
"""Create encoding for GeoZarr dataset variables."""
1440-
encoding: Dict[str, Any] = {}
1448+
encoding: dict[Hashable, XarrayEncodingJSON] = {}
14411449
for var in ds.data_vars:
14421450
if utils.is_grid_mapping_variable(ds, var):
14431451
encoding[var] = {"compressors": None}
@@ -1465,7 +1473,7 @@ def _create_geozarr_encoding(
14651473
return encoding
14661474

14671475

1468-
def _load_existing_dataset(path: str) -> Optional[xr.Dataset]:
1476+
def _load_existing_dataset(path: str) -> xr.Dataset | None:
14691477
"""Load existing dataset if it exists."""
14701478
try:
14711479
if fs_utils.path_exists(path):
@@ -1484,10 +1492,10 @@ def _load_existing_dataset(path: str) -> Optional[xr.Dataset]:
14841492

14851493

14861494
def _create_tile_matrix_limits(
1487-
overview_levels: List[Dict[str, Any]], tile_width: int
1488-
) -> Dict[str, Any]:
1495+
overview_levels: Iterable[OverviewLevelJSON], tile_width: int
1496+
) -> dict[str, TileMatrixLimitJSON]:
14891497
"""Create tile matrix limits for overview levels."""
1490-
tile_matrix_limits = {}
1498+
tile_matrix_limits: dict[str, TileMatrixLimitJSON] = {}
14911499
for ol in overview_levels:
14921500
level_str = str(ol["level"])
14931501
max_tile_col = int(np.ceil(ol["width"] / tile_width)) - 1
@@ -1500,10 +1508,11 @@ def _create_tile_matrix_limits(
15001508
"minTileRow": 0,
15011509
"maxTileRow": max_tile_row,
15021510
}
1511+
15031512
return tile_matrix_limits
15041513

15051514

1506-
def _get_x_coord_attrs() -> Dict[str, Any]:
1515+
def _get_x_coord_attrs() -> StandardXCoordAttrsJSON:
15071516
"""Get standard attributes for x coordinate."""
15081517
return {
15091518
"units": "m",
@@ -1513,7 +1522,7 @@ def _get_x_coord_attrs() -> Dict[str, Any]:
15131522
}
15141523

15151524

1516-
def _get_y_coord_attrs() -> Dict[str, Any]:
1525+
def _get_y_coord_attrs() -> StandardYCoordAttrsJSON:
15171526
"""Get standard attributes for y coordinate."""
15181527
return {
15191528
"units": "m",
@@ -1523,25 +1532,27 @@ def _get_y_coord_attrs() -> Dict[str, Any]:
15231532
}
15241533

15251534

1526-
def _get_at_coord_attrs() -> Dict[str, Any]:
1527-
"""Get standard attributes for azimuth_time coordinate."""
1535+
def _get_lon_coord_attrs() -> StandardLonCoordAttrsJSON:
1536+
"""Get standard attributes for longitude coordinate."""
15281537
return {
1529-
"long_name": "azimuth time",
1530-
"standard_name": "time",
1531-
"_ARRAY_DIMENSIONS": ["azimuth_time"],
1538+
"units": "degrees_east",
1539+
"long_name": "longitude",
1540+
"standard_name": "longitude",
1541+
"_ARRAY_DIMENSIONS": ["x"],
15321542
}
15331543

15341544

1535-
def _get_gr_coord_attrs() -> Dict[str, Any]:
1536-
"""Get standard attributes for ground_range coordinate."""
1545+
def _get_lat_coord_attrs() -> StandardLatCoordAttrsJSON:
1546+
"""Get standard attributes for latitude coordinate."""
15371547
return {
1538-
"long_name": "ground range distance",
1539-
"standard_name": "projection_x_coordinate",
1540-
"_ARRAY_DIMENSIONS": ["ground_range"],
1548+
"units": "degrees_north",
1549+
"long_name": "latitude",
1550+
"standard_name": "latitude",
1551+
"_ARRAY_DIMENSIONS": ["y"],
15411552
}
15421553

15431554

1544-
def _find_grid_mapping_var_name(ds: xr.Dataset, data_vars: List[Hashable]) -> str:
1555+
def _find_grid_mapping_var_name(ds: xr.Dataset, data_vars: list[Hashable]) -> str:
15451556
"""Find the grid_mapping variable name from the dataset."""
15461557
grid_mapping_var_name = ds.attrs.get("grid_mapping", None)
15471558
if not grid_mapping_var_name and data_vars:

0 commit comments

Comments (0)