Skip to content

Commit f171f46

Browse files
authored
Merge pull request #36 from d-v-b/chore/normalize-type-hints
chore/normalize type hints
2 parents 57024d6 + 20c2cf6 commit f171f46

3 files changed

Lines changed: 76 additions & 37 deletions

File tree

src/eopf_geozarr/conversion/fs_utils.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22

33
import json
44
import os
5-
from typing import Any, Dict, Optional
5+
from collections.abc import Mapping
6+
from typing import Any
67
from urllib.parse import urlparse
78

89
import s3fs
910
import zarr
11+
from fsspec.implementations.local import LocalFileSystem
12+
from s3fs import S3FileSystem
13+
14+
from eopf_geozarr.types import S3Credentials, S3FsOptions
1015

1116

1217
def normalize_s3_path(s3_path: str) -> str:
@@ -88,7 +93,7 @@ def parse_s3_path(s3_path: str) -> tuple[str, str]:
8893
return bucket, key
8994

9095

91-
def get_s3_storage_options(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
96+
def get_s3_storage_options(s3_path: str, **s3_kwargs: Any) -> S3FsOptions:
9297
"""
9398
Get storage options for S3 access with xarray.
9499
@@ -105,7 +110,7 @@ def get_s3_storage_options(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
105110
Storage options dictionary for xarray
106111
"""
107112
# Set up default S3 configuration
108-
default_s3_kwargs = {
113+
default_s3_kwargs: S3FsOptions = {
109114
"anon": False, # Use credentials
110115
"use_ssl": True,
111116
"client_kwargs": {
@@ -121,12 +126,12 @@ def get_s3_storage_options(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
121126
client_kwargs["endpoint_url"] = os.environ["AWS_ENDPOINT_URL"]
122127

123128
# Merge with user-provided kwargs
124-
s3_config = {**default_s3_kwargs, **s3_kwargs}
129+
s3_config: S3FsOptions = {**default_s3_kwargs, **s3_kwargs} # type: ignore[typeddict-item]
125130

126131
return s3_config
127132

128133

129-
def get_storage_options(path: str, **kwargs: Any) -> Optional[Dict[str, Any]]:
134+
def get_storage_options(path: str, **kwargs: Any) -> S3FsOptions | None:
130135
"""
131136
Get storage options for any URL type, leveraging fsspec as the abstraction layer.
132137
@@ -202,7 +207,7 @@ def create_s3_store(s3_path: str, **s3_kwargs: Any) -> str:
202207

203208

204209
def write_s3_json_metadata(
205-
s3_path: str, metadata: Dict[str, Any], **s3_kwargs: Any
210+
s3_path: str, metadata: Mapping[str, Any], **s3_kwargs: Any
206211
) -> None:
207212
"""
208213
Write JSON metadata directly to S3.
@@ -220,7 +225,7 @@ def write_s3_json_metadata(
220225
Additional keyword arguments for s3fs.S3FileSystem
221226
"""
222227
# Set up default S3 configuration
223-
default_s3_kwargs = {
228+
default_s3_kwargs: S3FsOptions = {
224229
"anon": False,
225230
"use_ssl": True,
226231
"asynchronous": False, # Force synchronous mode
@@ -245,7 +250,7 @@ def write_s3_json_metadata(
245250
f.write(json_content)
246251

247252

248-
def read_s3_json_metadata(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
253+
def read_s3_json_metadata(s3_path: str, **s3_kwargs: Any) -> dict[str, Any]:
249254
"""
250255
Read JSON metadata from S3.
251256
@@ -262,7 +267,7 @@ def read_s3_json_metadata(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
262267
Parsed JSON metadata
263268
"""
264269
# Set up default S3 configuration
265-
default_s3_kwargs = {
270+
default_s3_kwargs: S3FsOptions = {
266271
"anon": False,
267272
"use_ssl": True,
268273
"asynchronous": False, # Force synchronous mode
@@ -284,7 +289,7 @@ def read_s3_json_metadata(s3_path: str, **s3_kwargs: Any) -> Dict[str, Any]:
284289
with fs.open(s3_path, "r") as f:
285290
content = f.read()
286291

287-
result: Dict[str, Any] = json.loads(content)
292+
result: dict[str, Any] = json.loads(content)
288293
return result
289294

290295

@@ -304,7 +309,7 @@ def s3_path_exists(s3_path: str, **s3_kwargs: Any) -> bool:
304309
bool
305310
True if the path exists
306311
"""
307-
default_s3_kwargs = {
312+
default_s3_kwargs: S3FsOptions = {
308313
"anon": False,
309314
"use_ssl": True,
310315
"asynchronous": False, # Force synchronous mode
@@ -351,7 +356,7 @@ def open_s3_zarr_group(s3_path: str, mode: str = "r", **s3_kwargs: Any) -> zarr.
351356
)
352357

353358

354-
def get_s3_credentials_info() -> Dict[str, Optional[str]]:
359+
def get_s3_credentials_info() -> S3Credentials:
355360
"""
356361
Get information about available S3 credentials.
357362
@@ -372,7 +377,7 @@ def get_s3_credentials_info() -> Dict[str, Optional[str]]:
372377
}
373378

374379

375-
def validate_s3_access(s3_path: str, **s3_kwargs: Any) -> tuple[bool, Optional[str]]:
380+
def validate_s3_access(s3_path: str, **s3_kwargs: Any) -> tuple[bool, str | None]:
376381
"""
377382
Validate that we can access the S3 path.
378383
@@ -389,9 +394,9 @@ def validate_s3_access(s3_path: str, **s3_kwargs: Any) -> tuple[bool, Optional[s
389394
Tuple of (success, error_message)
390395
"""
391396
try:
392-
bucket, key = parse_s3_path(s3_path)
397+
bucket, _ = parse_s3_path(s3_path)
393398

394-
default_s3_kwargs = {
399+
default_s3_kwargs: S3FsOptions = {
395400
"anon": False,
396401
"use_ssl": True,
397402
"asynchronous": False, # Force synchronous mode
@@ -419,7 +424,7 @@ def validate_s3_access(s3_path: str, **s3_kwargs: Any) -> tuple[bool, Optional[s
419424
return False, str(e)
420425

421426

422-
def get_filesystem(path: str, **kwargs: Any) -> Any:
427+
def get_filesystem(path: str, **kwargs: Any) -> LocalFileSystem | S3FileSystem:
423428
"""
424429
Get the appropriate fsspec filesystem for any path type.
425430
@@ -435,18 +440,17 @@ def get_filesystem(path: str, **kwargs: Any) -> Any:
435440
fsspec.AbstractFileSystem
436441
Filesystem instance
437442
"""
438-
import fsspec
439443

440444
if is_s3_path(path):
441445
# Get S3 storage options and use them for fsspec
442446
storage_options = get_s3_storage_options(path, **kwargs)
443-
return fsspec.filesystem("s3", **storage_options)
447+
return S3FileSystem(**storage_options)
444448
else:
445449
# For local paths, use the local filesystem
446-
return fsspec.filesystem("file")
450+
return LocalFileSystem(**kwargs)
447451

448452

449-
def write_json_metadata(path: str, metadata: Dict[str, Any], **kwargs: Any) -> None:
453+
def write_json_metadata(path: str, metadata: dict[str, Any], **kwargs: Any) -> None:
450454
"""
451455
Write JSON metadata to any path type using fsspec.
452456
@@ -473,7 +477,7 @@ def write_json_metadata(path: str, metadata: Dict[str, Any], **kwargs: Any) -> N
473477
f.write(json_content)
474478

475479

476-
def read_json_metadata(path: str, **kwargs: Any) -> Dict[str, Any]:
480+
def read_json_metadata(path: str, **kwargs: Any) -> dict[str, Any]:
477481
"""
478482
Read JSON metadata from any path type using fsspec.
479483
@@ -494,7 +498,7 @@ def read_json_metadata(path: str, **kwargs: Any) -> Dict[str, Any]:
494498
with fs.open(path, "r") as f:
495499
content = f.read()
496500

497-
result: Dict[str, Any] = json.loads(content)
501+
result: dict[str, Any] = json.loads(content)
498502
return result
499503

500504

src/eopf_geozarr/conversion/geozarr.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import os
1919
import shutil
2020
import time
21-
from collections.abc import Hashable, Iterable, Mapping
22-
from typing import Any, Dict, List, Tuple
21+
from collections.abc import Hashable, Iterable, Mapping, Sequence
22+
from typing import Any, Dict, Tuple
2323

2424
import numpy as np
2525
import xarray as xr
@@ -55,7 +55,7 @@ def create_geozarr_dataset(
5555
min_dimension: int = 256,
5656
tile_width: int = 256,
5757
max_retries: int = 3,
58-
crs_groups: list[str] | None = None,
58+
crs_groups: Iterable[str] | None = None,
5959
gcp_group: str | None = None,
6060
) -> xr.DataTree:
6161
"""
@@ -77,8 +77,8 @@ def create_geozarr_dataset(
7777
Tile width for TMS compatibility
7878
max_retries : int, default 3
7979
Maximum number of retries for network operations
80-
crs_groups : list[str], optional
81-
List of group names that need CRS information added on best-effort basis
80+
crs_groups : Iterable[str], optional
81+
Iterable of group names that need CRS information added on best-effort basis
8282
gcp_group : str, optional
8383
Group name where GCPs (Ground Control Points) are located.
8484
@@ -228,7 +228,7 @@ def iterative_copy(
228228
min_dimension: int = 256,
229229
tile_width: int = 256,
230230
max_retries: int = 3,
231-
crs_groups: list[str] | None = None,
231+
crs_groups: Iterable[str] | None = None,
232232
gcp_group: str | None = None,
233233
) -> xr.DataTree:
234234
"""
@@ -252,8 +252,8 @@ def iterative_copy(
252252
Tile width for TMS compatibility
253253
max_retries : int, default 3
254254
Maximum number of retries for network operations
255-
crs_groups : list[str], optional
256-
List of group names that need CRS information added on best-effort basis
255+
crs_groups : Iterable[str], optional
256+
Iterable of group names that need CRS information added on best-effort basis
257257
gcp_group : str, optional
258258
Group name where GCPs (Ground Control Points) are located
259259
@@ -273,7 +273,7 @@ def iterative_copy(
273273
storage_options=storage_options,
274274
)
275275

276-
written_groups = set()
276+
written_groups: set[str] = set()
277277
reference_crs = None
278278

279279
# Process all groups in the tree using iterative approach
@@ -811,8 +811,8 @@ def create_native_crs_tile_matrix_set(
811811
Native CRS of the data
812812
native_bounds : tuple
813813
Native bounds (left, bottom, right, top)
814-
overview_levels : list
815-
List of overview level dictionaries
814+
overview_levels : Iterable[OverViewLevelJSON]
815+
Iterable of overview level dictionaries
816816
group_prefix : str, optional
817817
Group prefix for the tile matrix IDs
818818
@@ -883,7 +883,7 @@ def create_overview_dataset_all_vars(
883883
height: int,
884884
native_crs: Any,
885885
native_bounds: Tuple[float, float, float, float],
886-
data_vars: List[Hashable],
886+
data_vars: Sequence[Hashable],
887887
ds_gcp: xr.Dataset | None = None,
888888
) -> xr.Dataset:
889889
"""
@@ -903,8 +903,8 @@ def create_overview_dataset_all_vars(
903903
Native CRS of the data
904904
native_bounds : tuple
905905
Native bounds (left, bottom, right, top)
906-
data_vars : list
907-
List of data variable names to include
906+
data_vars : Sequence[Hashable]
907+
Sequence of data variable names to include
908908
ds_gcp : xr.Dataset, optional
909909
Source dataset with Sentinel-1 ground control points
910910
at native resolution
@@ -1552,7 +1552,7 @@ def _get_lat_coord_attrs() -> StandardLatCoordAttrsJSON:
15521552
}
15531553

15541554

1555-
def _find_grid_mapping_var_name(ds: xr.Dataset, data_vars: list[Hashable]) -> str:
1555+
def _find_grid_mapping_var_name(ds: xr.Dataset, data_vars: Sequence[Hashable]) -> str:
15561556
"""Find the grid_mapping variable name from the dataset."""
15571557
grid_mapping_var_name = ds.attrs.get("grid_mapping", None)
15581558
if not grid_mapping_var_name and data_vars:

src/eopf_geozarr/types.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,38 @@ class TileMatrixSetJSON(TypedDict):
9393
]
9494
"""A string literal indicating a resampling method"""
9595
XARRAY_DIMS_KEY: Final = "_ARRAY_DIMENSIONS"
96+
97+
98+
# Why is endpoint URL specified twice?
99+
class S3ClientOptions(TypedDict):
100+
"""
101+
S3 client options
102+
"""
103+
104+
region_name: NotRequired[str]
105+
endpoint_url: NotRequired[str]
106+
107+
108+
class S3FsOptions(TypedDict):
109+
"""
110+
S3FS options
111+
"""
112+
113+
anon: NotRequired[bool]
114+
use_ssl: NotRequired[bool]
115+
client_kwargs: NotRequired[S3ClientOptions]
116+
endpoint_url: NotRequired[str]
117+
asynchronous: NotRequired[bool]
118+
119+
120+
class S3Credentials(TypedDict):
121+
"""
122+
S3 credentials
123+
"""
124+
125+
aws_access_key_id: str | None
126+
aws_secret_access_key: str | None
127+
aws_session_token: str | None
128+
aws_default_region: str
129+
aws_profile: str | None
130+
AWS_ENDPOINT_URL: str | None

0 commit comments

Comments
 (0)