Skip to content

Commit c24b456

Browse files
Enhance GeoZarr conversion: add support for Dask cluster setup for parallel processing, update chunking strategy for datasets, and ensure compatibility with existing S3 handling.
1 parent eaef673 commit c24b456

3 files changed

Lines changed: 136 additions & 62 deletions

File tree

.vscode/launch.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"--min-dimension", "256",
4545
"--tile-width", "256",
4646
"--max-retries", "2",
47+
"--dask-cluster",
4748
"--verbose"
4849
],
4950
"cwd": "${workspaceFolder}",

eopf_geozarr/cli.py

Lines changed: 120 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,56 @@
88
import argparse
99
import sys
1010
from pathlib import Path
11+
from typing import Optional
1112

1213
import xarray as xr
1314

1415
from . import create_geozarr_dataset
1516
from .conversion import is_s3_path, validate_s3_access, get_s3_credentials_info
1617

1718

19+
def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[object]:
    """
    Start a local Dask distributed cluster when requested.

    Parameters
    ----------
    enable_dask : bool
        When False, no cluster is started and None is returned.
    verbose : bool, default False
        When True, report the client repr, dashboard URL, and worker count.

    Returns
    -------
    dask.distributed.Client or None
        The connected client, or None when ``enable_dask`` is False.

    Notes
    -----
    Exits the process (status 1) if ``dask.distributed`` is missing or the
    cluster fails to start — this is a CLI-boundary helper, not a library call.
    """
    if not enable_dask:
        return None

    try:
        # Imported lazily so the CLI works without the optional dask extra.
        from dask.distributed import Client

        cluster_client = Client()  # local cluster with default settings

        if not verbose:
            print("🚀 Dask cluster started for parallel processing")
            return cluster_client

        # Verbose diagnostics stay inside the try: a failing scheduler query
        # is treated the same as any other startup error.
        worker_count = len(cluster_client.scheduler_info()['workers'])
        print(f"🚀 Dask cluster started: {cluster_client}")
        print(f" Dashboard: {cluster_client.dashboard_link}")
        print(f" Workers: {worker_count}")
        return cluster_client
    except ImportError:
        print("❌ Error: dask.distributed not available. Install with: pip install 'dask[distributed]'")
        sys.exit(1)
    except Exception as e:
        print(f"❌ Error starting dask cluster: {e}")
        sys.exit(1)
60+
1861
def convert_command(args: argparse.Namespace) -> None:
1962
"""
2063
Convert EOPF dataset to GeoZarr compliant format.
@@ -24,64 +67,70 @@ def convert_command(args: argparse.Namespace) -> None:
2467
args : argparse.Namespace
2568
Command line arguments
2669
"""
27-
# Validate input path (handle both local paths and URLs)
28-
input_path_str = args.input_path
29-
if input_path_str.startswith(("http://", "https://", "s3://", "gs://")):
30-
# URL - no local validation needed
31-
input_path = input_path_str
32-
else:
33-
# Local path - validate existence
34-
input_path = Path(input_path_str)
35-
if not input_path.exists():
36-
print(f"Error: Input path {input_path} does not exist")
37-
sys.exit(1)
38-
input_path = str(input_path)
39-
40-
# Handle output path validation
41-
output_path_str = args.output_path
42-
if is_s3_path(output_path_str):
43-
# S3 path - validate S3 access
44-
print("🔍 Validating S3 access...")
45-
success, error_msg = validate_s3_access(output_path_str)
46-
if not success:
47-
print(f"❌ Error: Cannot access S3 path {output_path_str}")
48-
print(f" Reason: {error_msg}")
49-
print("\n💡 S3 Configuration Help:")
50-
print(" Make sure you have S3 credentials configured:")
51-
print(" - Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables")
52-
print(" - Set AWS_DEFAULT_REGION (default: us-east-1)")
53-
print(" - For custom S3 providers (e.g., OVH Cloud), set AWS_S3_ENDPOINT")
54-
print(" - Or configure AWS CLI with 'aws configure'")
55-
print(" - Or use IAM roles if running on EC2")
56-
57-
if args.verbose:
58-
creds_info = get_s3_credentials_info()
59-
print(f"\n🔧 Current AWS configuration:")
60-
for key, value in creds_info.items():
61-
print(f" {key}: {value or 'Not set'}")
70+
# Set up dask cluster if requested
71+
dask_client = setup_dask_cluster(
72+
enable_dask=getattr(args, 'dask_cluster', False),
73+
verbose=args.verbose
74+
)
75+
76+
try:
77+
# Validate input path (handle both local paths and URLs)
78+
input_path_str = args.input_path
79+
if input_path_str.startswith(("http://", "https://", "s3://", "gs://")):
80+
# URL - no local validation needed
81+
input_path = input_path_str
82+
else:
83+
# Local path - validate existence
84+
input_path = Path(input_path_str)
85+
if not input_path.exists():
86+
print(f"Error: Input path {input_path} does not exist")
87+
sys.exit(1)
88+
input_path = str(input_path)
89+
90+
# Handle output path validation
91+
output_path_str = args.output_path
92+
if is_s3_path(output_path_str):
93+
# S3 path - validate S3 access
94+
print("🔍 Validating S3 access...")
95+
success, error_msg = validate_s3_access(output_path_str)
96+
if not success:
97+
print(f"❌ Error: Cannot access S3 path {output_path_str}")
98+
print(f" Reason: {error_msg}")
99+
print("\n💡 S3 Configuration Help:")
100+
print(" Make sure you have S3 credentials configured:")
101+
print(" - Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables")
102+
print(" - Set AWS_DEFAULT_REGION (default: us-east-1)")
103+
print(" - For custom S3 providers (e.g., OVH Cloud), set AWS_S3_ENDPOINT")
104+
print(" - Or configure AWS CLI with 'aws configure'")
105+
print(" - Or use IAM roles if running on EC2")
106+
107+
if args.verbose:
108+
creds_info = get_s3_credentials_info()
109+
print(f"\n🔧 Current AWS configuration:")
110+
for key, value in creds_info.items():
111+
print(f" {key}: {value or 'Not set'}")
112+
113+
sys.exit(1)
62114

63-
sys.exit(1)
64-
65-
print("✅ S3 access validated successfully")
66-
output_path = output_path_str
67-
else:
68-
# Local path - create directory if it doesn't exist
69-
output_path = Path(output_path_str)
70-
output_path.parent.mkdir(parents=True, exist_ok=True)
71-
output_path = str(output_path)
72-
73-
if args.verbose:
74-
print(f"Loading EOPF dataset from: {input_path}")
75-
print(f"Groups to convert: {args.groups}")
76-
print(f"Output path: {output_path}")
77-
print(f"Spatial chunk size: {args.spatial_chunk}")
78-
print(f"Min dimension: {args.min_dimension}")
79-
print(f"Tile width: {args.tile_width}")
115+
print("✅ S3 access validated successfully")
116+
output_path = output_path_str
117+
else:
118+
# Local path - create directory if it doesn't exist
119+
output_path = Path(output_path_str)
120+
output_path.parent.mkdir(parents=True, exist_ok=True)
121+
output_path = str(output_path)
122+
123+
if args.verbose:
124+
print(f"Loading EOPF dataset from: {input_path}")
125+
print(f"Groups to convert: {args.groups}")
126+
print(f"Output path: {output_path}")
127+
print(f"Spatial chunk size: {args.spatial_chunk}")
128+
print(f"Min dimension: {args.min_dimension}")
129+
print(f"Tile width: {args.tile_width}")
80130

81-
try:
82131
# Load the EOPF DataTree
83132
print("Loading EOPF dataset...")
84-
dt = xr.open_datatree(str(input_path), engine="zarr")
133+
dt = xr.open_datatree(str(input_path), engine="zarr", chunks="auto")
85134

86135
if args.verbose:
87136
print(f"Loaded DataTree with {len(dt.children)} groups")
@@ -117,6 +166,16 @@ def convert_command(args: argparse.Namespace) -> None:
117166

118167
traceback.print_exc()
119168
sys.exit(1)
169+
finally:
170+
# Clean up dask client if it was created
171+
if dask_client is not None:
172+
try:
173+
dask_client.close()
174+
if args.verbose:
175+
print("🔄 Dask cluster closed")
176+
except Exception as e:
177+
if args.verbose:
178+
print(f"Warning: Error closing dask cluster: {e}")
120179

121180

122181
def info_command(args: argparse.Namespace) -> None:
@@ -143,7 +202,7 @@ def info_command(args: argparse.Namespace) -> None:
143202

144203
try:
145204
print(f"Loading dataset from: {input_path}")
146-
dt = xr.open_datatree(input_path, engine="zarr")
205+
dt = xr.open_datatree(input_path, engine="zarr", chunks="auto")
147206

148207
print("\nDataset Information:")
149208
print("==================")
@@ -185,7 +244,7 @@ def validate_command(args: argparse.Namespace) -> None:
185244

186245
try:
187246
print(f"Validating GeoZarr compliance for: {input_path}")
188-
dt = xr.open_datatree(input_path, engine="zarr")
247+
dt = xr.open_datatree(input_path, engine="zarr", chunks="auto")
189248

190249
compliance_issues = []
191250
total_variables = 0
@@ -306,6 +365,11 @@ def create_parser() -> argparse.ArgumentParser:
306365
help="Maximum number of retries for network operations (default: 3)",
307366
)
308367
convert_parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
368+
convert_parser.add_argument(
369+
"--dask-cluster",
370+
action="store_true",
371+
help="Start a local dask cluster for parallel processing of chunks"
372+
)
309373
convert_parser.set_defaults(func=convert_command)
310374

311375
# Info command

eopf_geozarr/conversion/geozarr.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,17 @@ def recursive_copy(
269269
if dt_node.data_vars:
270270
# Set up encoding for variables
271271
for var in ds.data_vars:
272+
# create chunks to the size of the data
273+
data_shape = ds[var].shape
274+
if len(data_shape) >= 2:
275+
chunking = (
276+
1,
277+
data_shape[-2],
278+
data_shape[-1],
279+
)
272280
encoding[var] = {
273281
"compressors": [compressor],
282+
"chunks": chunking
274283
}
275284
for coord in ds.coords:
276285
encoding[coord] = {
@@ -527,11 +536,11 @@ def write_geozarr_group(
527536
if s3_utils.s3_path_exists(native_dataset_path):
528537
store = s3_utils.create_s3_store(native_dataset_path)
529538
storage_options = s3_utils.get_s3_storage_options(native_dataset_path)
530-
existing_native_dataset = xr.open_zarr(store, zarr_format=3, storage_options=storage_options)
539+
existing_native_dataset = xr.open_zarr(store, zarr_format=3, storage_options=storage_options, chunks="auto")
531540
print(f"Found existing native dataset at {native_dataset_path}")
532541
else:
533542
if os.path.exists(native_dataset_path):
534-
existing_native_dataset = xr.open_zarr(native_dataset_path, zarr_format=3)
543+
existing_native_dataset = xr.open_zarr(native_dataset_path, zarr_format=3, chunks="auto")
535544
print(f"Found existing native dataset at {native_dataset_path}")
536545
except Exception as e:
537546
print(f"Warning: Could not open existing native dataset at {native_dataset_path}: {e}")
@@ -583,10 +592,10 @@ def write_geozarr_group(
583592
if s3_utils.is_s3_path(output_path):
584593
store = s3_utils.create_s3_store(group_path)
585594
storage_options = s3_utils.get_s3_storage_options(output_path)
586-
ds = xr.open_dataset(store, engine="zarr", zarr_format=3, decode_coords="all", storage_options=storage_options).compute()
595+
ds = xr.open_dataset(store, engine="zarr", zarr_format=3, decode_coords="all", storage_options=storage_options, chunks="auto").compute()
587596
else:
588597
ds = xr.open_dataset(
589-
group_path, engine="zarr", zarr_format=3, decode_coords="all"
598+
group_path, engine="zarr", zarr_format=3, decode_coords="all", chunks="auto"
590599
).compute()
591600

592601
return ds
@@ -1292,9 +1301,9 @@ def write_dataset_band_by_band_with_validation(
12921301
if s3_utils.is_s3_path(output_path):
12931302
store = s3_utils.create_s3_store(output_path)
12941303
storage_options = s3_utils.get_s3_storage_options(output_path)
1295-
ds = xr.open_dataset(store, engine="zarr", zarr_format=3, decode_coords="all", storage_options=storage_options).compute()
1304+
ds = xr.open_dataset(store, engine="zarr", zarr_format=3, decode_coords="all", storage_options=storage_options, chunks="auto").compute()
12961305
else:
1297-
ds = xr.open_dataset(output_path, engine="zarr", zarr_format=3, decode_coords="all").compute()
1306+
ds = xr.open_dataset(output_path, engine="zarr", zarr_format=3, decode_coords="all", chunks="auto").compute()
12981307

12991308
# Report results
13001309
if failed_vars:

0 commit comments

Comments
 (0)