@@ -57,6 +57,7 @@ def create_geozarr_dataset(
5757 max_retries : int = 3 ,
5858 crs_groups : Iterable [str ] | None = None ,
5959 gcp_group : str | None = None ,
60+ enable_sharding : bool = False ,
6061) -> xr .DataTree :
6162 """
6263 Create a GeoZarr-spec 0.4 compliant dataset from EOPF data.
@@ -81,6 +82,8 @@ def create_geozarr_dataset(
8182 Iterable of group names that need CRS information added on best-effort basis
8283 gcp_group : str, optional
8384 Group name where GCPs (Ground Control Points) are located.
85+ enable_sharding : bool, default False
86+ Enable zarr sharding for spatial dimensions of each variable
8487
8588 Returns
8689 -------
@@ -90,6 +93,9 @@ def create_geozarr_dataset(
9093 dt = dt_input .copy ()
9194 compressor = BloscCodec (cname = "zstd" , clevel = 3 , shuffle = "shuffle" , blocksize = 0 )
9295
96+ if enable_sharding :
97+ print ("🔧 Zarr sharding enabled for spatial dimensions" )
98+
9399 if _is_sentinel1 (dt_input ):
94100 if gcp_group is None :
95101 raise ValueError (
@@ -132,6 +138,7 @@ def create_geozarr_dataset(
132138 max_retries ,
133139 crs_groups ,
134140 gcp_group ,
141+ enable_sharding ,
135142 )
136143
137144 # Consolidate metadata at the root level AFTER all groups are written
@@ -230,6 +237,7 @@ def iterative_copy(
230237 max_retries : int = 3 ,
231238 crs_groups : Iterable [str ] | None = None ,
232239 gcp_group : str | None = None ,
240+ enable_sharding : bool = False ,
233241) -> xr .DataTree :
234242 """
235243 Iteratively copy groups from original DataTree to GeoZarr DataTree.
@@ -301,6 +309,7 @@ def iterative_copy(
301309 min_dimension = min_dimension ,
302310 tile_width = tile_width ,
303311 gcp_group = gcp_group ,
312+ enable_sharding = enable_sharding ,
304313 )
305314 written_groups .add (current_group_path )
306315 continue
@@ -407,6 +416,7 @@ def write_geozarr_group(
407416 min_dimension : int = 256 ,
408417 tile_width : int = 256 ,
409418 gcp_group : str | None = None ,
419+ enable_sharding : bool = False ,
410420) -> xr .DataTree :
411421 """
412422 Write a group to a GeoZarr dataset with multiscales support.
@@ -451,7 +461,7 @@ def write_geozarr_group(
451461 dt .attrs = ds .attrs .copy ()
452462
453463 # Create encoding for all variables
454- encoding = _create_geozarr_encoding (ds , compressor , spatial_chunk )
464+ encoding = _create_geozarr_encoding (ds , compressor , spatial_chunk , enable_sharding )
455465
456466 # Write native data in the group 0 (overview level 0)
457467 native_dataset_group_name = f"{ group_name } /0"
@@ -492,6 +502,7 @@ def write_geozarr_group(
492502 tile_width = tile_width ,
493503 spatial_chunk = spatial_chunk ,
494504 ds_gcp = ds_gcp ,
505+ enable_sharding = enable_sharding ,
495506 )
496507 except Exception as e :
497508 print (
@@ -517,6 +528,7 @@ def create_geozarr_compliant_multiscales(
517528 tile_width : int = 256 ,
518529 spatial_chunk : int = 4096 ,
519530 ds_gcp : xr .Dataset | None = None ,
531+ enable_sharding : bool = False ,
520532) -> Dict [str , Any ]:
521533 """
522534 Create GeoZarr-spec compliant multiscales following the specification exactly.
@@ -674,10 +686,13 @@ def create_geozarr_compliant_multiscales(
674686 native_bounds ,
675687 data_vars ,
676688 ds_gcp_overview ,
689+ enable_sharding ,
677690 )
678691
679692 # Create encoding for this overview level
680- encoding = _create_geozarr_encoding (overview_ds , compressor , spatial_chunk )
693+ encoding = _create_geozarr_encoding (
694+ overview_ds , compressor , spatial_chunk , enable_sharding
695+ )
681696
682697 # Write overview level
683698 overview_path = fs_utils .normalize_path (f"{ output_path } /{ group_name } /{ level } " )
@@ -885,6 +900,7 @@ def create_overview_dataset_all_vars(
885900 native_bounds : Tuple [float , float , float , float ],
886901 data_vars : Sequence [Hashable ],
887902 ds_gcp : xr .Dataset | None = None ,
903+ enable_sharding : bool = False ,
888904) -> xr .Dataset :
889905 """
890906 Create an overview dataset containing all variables for a specific level.
@@ -1090,7 +1106,21 @@ def write_dataset_band_by_band_with_validation(
10901106 for attempt in range (max_retries ):
10911107 try :
10921108 # Ensure the dataset is properly chunked to align with encoding
1093- if var in var_encoding and "chunks" in var_encoding [var ]:
1109+ if (
1110+ var in var_encoding
1111+ and "shards" in var_encoding [var ]
1112+ and var_encoding [var ]["shards" ] is not None
1113+ ):
1114+ # For sharded variables, use the shards dimensions
1115+ shard_dims = var_encoding [var ].get ("shards" , None )
1116+ if shard_dims is not None :
1117+ var_dims = single_var_ds [var ].dims
1118+ chunk_dict = {}
1119+ for i , dim in enumerate (var_dims ):
1120+ if i < len (shard_dims ):
1121+ chunk_dict [dim ] = shard_dims [i ]
1122+ single_var_ds [var ] = single_var_ds [var ].chunk (chunk_dict )
1123+ elif var in var_encoding and "chunks" in var_encoding [var ]:
10941124 target_chunks = var_encoding [var ]["chunks" ]
10951125 # Create chunk dict using the actual dimensions of the variable
10961126 var_dims = single_var_ds [var ].dims
@@ -1442,10 +1472,11 @@ def _create_encoding(
14421472
14431473
14441474def _create_geozarr_encoding (
1445- ds : xr .Dataset , compressor : Any , spatial_chunk : int
1475+ ds : xr .Dataset , compressor : Any , spatial_chunk : int , enable_sharding : bool = False
14461476) -> dict [Hashable , XarrayEncodingJSON ]:
14471477 """Create encoding for GeoZarr dataset variables."""
14481478 encoding : dict [Hashable , XarrayEncodingJSON ] = {}
1479+ chunks : tuple [int , ...]
14491480 for var in ds .data_vars :
14501481 if utils .is_grid_mapping_variable (ds , var ):
14511482 encoding [var ] = {"compressors" : None }
@@ -1458,12 +1489,54 @@ def _create_geozarr_encoding(
14581489 utils .calculate_aligned_chunk_size (width , spatial_chunk ),
14591490 utils .calculate_aligned_chunk_size (height , spatial_chunk ),
14601491 )
1492+
1493+ if len (data_shape ) == 3 :
1494+ chunks = (1 , spatial_chunk_aligned , spatial_chunk_aligned )
1495+ else :
1496+ chunks = (spatial_chunk_aligned , spatial_chunk_aligned )
14611497 else :
14621498 spatial_chunk_aligned = spatial_chunk
1499+ chunks = (spatial_chunk_aligned ,)
1500+
1501+ shards : tuple [int , ...] | None = None
1502+
1503+ if enable_sharding :
1504+ # Calculate shard dimensions that are divisible by chunk dimensions
1505+ if len (data_shape ) == 3 :
1506+ # For 3D data (time, y, x), ensure shard dimensions are divisible by chunks
1507+ shard_time = data_shape [0 ] # Keep full time dimension
1508+ shard_y = _calculate_shard_dimension (data_shape [1 ], chunks [1 ])
1509+ shard_x = _calculate_shard_dimension (data_shape [2 ], chunks [2 ])
1510+ shards = (shard_time , shard_y , shard_x )
1511+ print (
1512+ f" 🔧 Sharding config for { var } : data_shape={ data_shape } , chunks={ chunks } , shards={ shards } "
1513+ )
1514+ elif len (data_shape ) == 2 :
1515+ # For 2D data (y, x), ensure shard dimensions are divisible by chunks
1516+ shard_y = _calculate_shard_dimension (data_shape [0 ], chunks [0 ])
1517+ shard_x = _calculate_shard_dimension (data_shape [1 ], chunks [1 ])
1518+ shards = (shard_y , shard_x )
1519+ print (
1520+ f" 🔧 Sharding config for { var } : data_shape={ data_shape } , chunks={ chunks } , shards={ shards } "
1521+ )
1522+ else :
1523+ # For 1D data, use the full dimension
1524+ shards = (data_shape [0 ],)
1525+ print (
1526+ f" 🔧 Sharding config for { var } : data_shape={ data_shape } , chunks={ chunks } , shards={ shards } "
1527+ )
1528+
1529+ # Validate that shards are evenly divisible by chunks
1530+ for i , (shard_dim , chunk_dim ) in enumerate (zip (shards , chunks )):
1531+ if shard_dim % chunk_dim != 0 :
1532+ print (
1533+ f" ⚠️ Warning: Shard dimension { shard_dim } not evenly divisible by chunk dimension { chunk_dim } at axis { i } "
1534+ )
14631535
14641536 encoding [var ] = {
1465- "chunks" : ( spatial_chunk_aligned , spatial_chunk_aligned ) ,
1537+ "chunks" : chunks ,
14661538 "compressors" : compressor ,
1539+ "shards" : shards ,
14671540 }
14681541
14691542 # Add coordinate encoding
@@ -1618,6 +1691,46 @@ def _add_grid_mapping_variable(
16181691 print (f" Added grid_mapping attribute to { var_name } " )
16191692
16201693
1694+ def _calculate_shard_dimension (data_dim : int , chunk_dim : int ) -> int :
1695+ """
1696+ Calculate shard dimension that is evenly divisible by chunk dimension.
1697+
1698+ For Zarr v3 sharding with Dask, the shard dimension must be evenly
1699+ divisible by the chunk dimension to avoid checksum mismatches.
1700+
1701+ Parameters
1702+ ----------
1703+ data_dim : int
1704+ Size of the data dimension
1705+ chunk_dim : int
1706+ Size of the chunk dimension
1707+
1708+ Returns
1709+ -------
1710+ int
1711+ Shard dimension that is evenly divisible by chunk_dim
1712+ """
1713+ # If chunk is larger than data dimension, the effective chunk will be data_dim
1714+ # In this case, shard should also be data_dim to maintain divisibility
1715+ if chunk_dim >= data_dim :
1716+ return data_dim
1717+
1718+ # Calculate how many complete chunks fit in the data dimension
1719+ num_complete_chunks = data_dim // chunk_dim
1720+
1721+ # If we have at least 2 complete chunks, use a multiple of chunk_dim
1722+ if num_complete_chunks >= 2 :
1723+ # Use a shard size that's a multiple of chunk_dim
1724+ for multiplier in range (num_complete_chunks + 1 , 2 , - 1 ):
1725+ shard_size = multiplier * chunk_dim
1726+ if shard_size <= data_dim :
1727+ return shard_size
1728+
1729+ # Fallback: use the largest multiple of chunk_dim that fits
1730+ # If no complete chunks fit, use data_dim (this handles edge cases)
1731+ return num_complete_chunks * chunk_dim if num_complete_chunks > 0 else data_dim
1732+
1733+
16211734def _is_sentinel1 (dt : xr .DataTree ) -> bool :
16221735 """Return True if the input DataTree represents a Sentinel-1 product."""
16231736 stac_props = dt .attrs .get ("stac_discovery" , {}).get ("properties" , {})
0 commit comments