Skip to content

Commit 5be1f10

Browse files
authored
Merge pull request #154 from d-v-b/feat/scale-offset-codecs
feat: implement scale-offset and data type casting via codecs
2 parents 0d4a880 + 2fcf0fb commit 5be1f10

20 files changed

Lines changed: 14121 additions & 13160 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ requires-python = ">=3.12"
2929
dependencies = [
3030
"pydantic-zarr>=0.8.0",
3131
"pydantic>=2.12",
32-
"zarr>=3.1.1",
32+
"zarr[cast-value-rs]>=3.2.0",
3333
"xarray>=2025.7.1",
3434
"dask[array,distributed]>=2026.1.0",
3535
"numpy>=2.3.1",

src/eopf_geozarr/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,11 @@ def add_s2_optimization_commands(subparsers: argparse._SubParsersAction) -> None
11651165
scale-offset encoding will be re-saved as the decoded data type, i.e. floating point values.
11661166
""",
11671167
)
1168+
s2_parser.add_argument(
1169+
"--experimental-scale-offset-codec",
1170+
action="store_true",
1171+
help="Push CF scale-offset encoding into zarr codec pipeline instead of decoding to float.",
1172+
)
11681173
s2_parser.add_argument(
11691174
"--dask-cluster",
11701175
action="store_true",
@@ -1197,6 +1202,7 @@ def convert_s2_optimized_command(args: argparse.Namespace) -> None:
11971202
compression_level=args.compression_level,
11981203
validate_output=not args.skip_validation,
11991204
keep_scale_offset=args.keep_scale_offset,
1205+
experimental_scale_offset_codec=args.experimental_scale_offset_codec,
12001206
)
12011207

12021208
log.info("✅ S2 optimization completed", output_path=args.output_path)

src/eopf_geozarr/codecs/__init__.py

Whitespace-only changes.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
CF-to-zarr-codec helper for the `scale_offset` codec.
3+
4+
The `scale_offset` codec itself ships with zarr-python >= 3.2.0
5+
(`zarr.codecs.ScaleOffset`); this module only provides the small mapping from
6+
CF-convention `scale_factor` / `add_offset` attributes to `ScaleOffset`
7+
constructor arguments.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from zarr.codecs import ScaleOffset
13+
14+
15+
def scale_offset_from_cf(*, scale_factor: float, add_offset: float) -> ScaleOffset:
16+
"""
17+
Convert CF-convention scale_factor and add_offset to a ScaleOffset codec.
18+
19+
CF convention: unpacked = packed * scale_factor + add_offset
20+
21+
ScaleOffset convention:
22+
encode: out = (in - offset) * scale
23+
decode: out = (in / scale) + offset
24+
25+
To match CF: offset = add_offset, scale = 1 / scale_factor.
26+
"""
27+
return ScaleOffset(offset=add_offset, scale=1.0 / scale_factor)

src/eopf_geozarr/conversion/geozarr.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,14 @@ def _create_encoding(
14791479
else:
14801480
chunking = (min(spatial_chunk, data_shape[-1]),)
14811481

1482-
encoding[var] = {"compressors": [compressor], "chunks": chunking}
1482+
var_encoding: XarrayEncodingJSON = {
1483+
"compressors": [compressor],
1484+
"chunks": chunking,
1485+
}
1486+
fv = utils.explicit_fill_value(ds[var])
1487+
if fv is not utils.UNSET:
1488+
var_encoding["fill_value"] = fv
1489+
encoding[var] = var_encoding
14831490

14841491
# Add coordinate encoding
14851492
for coord in ds.coords:
@@ -1565,11 +1572,15 @@ def _create_geozarr_encoding(
15651572
axis=i,
15661573
)
15671574

1568-
encoding[var] = {
1575+
var_encoding: XarrayEncodingJSON = {
15691576
"chunks": chunks,
15701577
"compressors": compressor,
15711578
"shards": shards,
15721579
}
1580+
fv = utils.explicit_fill_value(ds[var])
1581+
if fv is not utils.UNSET:
1582+
var_encoding["fill_value"] = fv
1583+
encoding[var] = var_encoding
15731584

15741585
# Add coordinate encoding
15751586
for coord in ds.coords:

src/eopf_geozarr/conversion/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utility functions for GeoZarr conversion."""
22

3+
from typing import Any
4+
35
import numpy as np
46
import rasterio # noqa: F401 # Import to enable .rio accessor
57
import structlog
@@ -8,6 +10,39 @@
810
log = structlog.get_logger()
911

1012

13+
# Sentinel: distinguish "no explicit fill_value" from a legitimate `None`.
14+
UNSET: Any = object()
15+
16+
17+
def explicit_fill_value(var: xr.DataArray) -> Any:
18+
"""Pick a zarr-level `fill_value` for `var` based on its source `_FillValue`.
19+
20+
Different xarray versions infer different on-disk fill values when the
21+
encoding dict doesn't pin it: older xarray defaults floats to 0.0; newer
22+
xarray honours the source `_FillValue`. Setting `fill_value` explicitly
23+
via this helper removes that degree of freedom so the on-disk metadata is
24+
stable across xarray versions.
25+
26+
Returns
27+
-------
28+
object
29+
The value to assign to `encoding["fill_value"]`. The sentinel `UNSET`
30+
is returned when the source has no `_FillValue` (caller should leave
31+
the encoding entry alone). For non-finite floats, returns the
32+
JSON-canonical string form (`"NaN"` / `"Infinity"` / `"-Infinity"`)
33+
that zarr-python serialises.
34+
"""
35+
source_fill = var.encoding.get("_FillValue")
36+
if source_fill is None:
37+
return UNSET
38+
fill_arr = np.asarray(source_fill)
39+
if np.issubdtype(fill_arr.dtype, np.floating) and not np.isfinite(fill_arr):
40+
if np.isnan(fill_arr):
41+
return "NaN"
42+
return "Infinity" if fill_arr > 0 else "-Infinity"
43+
return source_fill
44+
45+
1146
def downsample_2d_array(
1247
source_data: np.ndarray,
1348
target_height: int,

src/eopf_geozarr/s2_optimization/s2_converter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def convert_s2_optimized(
186186
compression_level: int,
187187
validate_output: bool,
188188
keep_scale_offset: bool,
189+
experimental_scale_offset_codec: bool = False,
189190
max_retries: int = 3,
190191
) -> xr.DataTree:
191192
"""
@@ -199,6 +200,7 @@ def convert_s2_optimized(
199200
compression_level: Compression level 1-9
200201
validate_output: Whether to validate the output
201202
keep_scale_offset: Whether to preserve scale-offset encoding of the source data.
203+
experimental_scale_offset_codec: Push CF scale-offset into zarr codec pipeline.
202204
max_retries: Maximum number of retries for network operations
203205
204206
Returns:
@@ -234,6 +236,7 @@ def convert_s2_optimized(
234236
enable_sharding=enable_sharding,
235237
crs=crs,
236238
keep_scale_offset=keep_scale_offset,
239+
experimental_scale_offset_codec=experimental_scale_offset_codec,
237240
)
238241

239242
log.info("Created multiscale pyramids", num_groups=len(datasets))

src/eopf_geozarr/s2_optimization/s2_multiscale.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from dask.array import from_delayed
1616
from pydantic.experimental.missing_sentinel import MISSING
1717
from pyproj import CRS
18+
from zarr.codecs import CastValue
1819
from zarr_cm import geo_proj
1920
from zarr_cm import multiscales as multiscales_cm
2021
from zarr_cm import spatial as spatial_cm
2122

23+
from eopf_geozarr.conversion import utils
2224
from eopf_geozarr.conversion.fs_utils import sanitize_dataset_attributes
2325
from eopf_geozarr.conversion.geozarr import (
2426
_create_tile_matrix_limits,
@@ -84,8 +86,14 @@ def _coarsen_variable(var_name: str, var_data: xr.DataArray, factor: int) -> xr.
8486
else:
8587
raise ValueError(f"Unknown variable type {var_type}")
8688

87-
result.encoding = var_data.encoding
88-
return result.astype(var_data.dtype)
89+
# `xr.DataArray.astype` clears `.encoding`, so we capture it first and
90+
# restore it on the cast result. Without this, downstream code that
91+
# inspects encoding (e.g. to push CF scale-offset into a codec pipeline)
92+
# would see an empty encoding on every coarsened level.
93+
encoding = var_data.encoding
94+
result = result.astype(var_data.dtype)
95+
result.encoding = encoding
96+
return result
8997

9098

9199
def inject_missing_bands(
@@ -170,6 +178,7 @@ def create_multiscale_from_datatree(
170178
spatial_chunk: int,
171179
crs: CRS | None = None,
172180
keep_scale_offset: bool,
181+
experimental_scale_offset_codec: bool = False,
173182
) -> dict[str, dict]:
174183
"""
175184
Create multiscale versions preserving original structure.
@@ -239,11 +248,17 @@ def create_multiscale_from_datatree(
239248
spatial_chunk=spatial_chunk,
240249
enable_sharding=enable_sharding,
241250
keep_scale_offset=keep_scale_offset,
251+
experimental_scale_offset_codec=experimental_scale_offset_codec,
242252
)
243-
# convert float64 arrays to float32
253+
# convert float64 arrays to float32. `xr.DataArray.astype` clears
254+
# encoding, so we capture and restore it — downstream pyramid
255+
# levels are coarsened from this dataset and rely on the encoding
256+
# to drive CF packing / codec filter generation.
244257
for data_var in dataset.data_vars:
245258
if dataset[data_var].dtype in (np.dtype("<f8"), np.dtype(">f8")):
259+
var_encoding = dataset[data_var].encoding
246260
dataset[data_var] = dataset[data_var].astype("float32")
261+
dataset[data_var].encoding = var_encoding
247262
# Clear _FillValue from the DataArray's own encoding to prevent
248263
# xarray from raising "Zarr does not support _FillValue in encoding".
249264
if not keep_scale_offset:
@@ -300,6 +315,7 @@ def create_multiscale_from_datatree(
300315
spatial_chunk=spatial_chunk,
301316
enable_sharding=enable_sharding,
302317
keep_scale_offset=keep_scale_offset,
318+
experimental_scale_offset_codec=experimental_scale_offset_codec,
303319
)
304320

305321
# Strip _FillValue from DataArray encoding for downsampled levels too
@@ -343,6 +359,7 @@ def create_measurements_encoding(
343359
spatial_chunk: int,
344360
enable_sharding: bool = True,
345361
keep_scale_offset: bool = True,
362+
experimental_scale_offset_codec: bool = False,
346363
) -> dict[str, XarrayDataArrayEncoding]:
347364
"""
348365
Create optimized encoding for a pyramid level with advanced chunking and sharding.
@@ -390,7 +407,48 @@ def create_measurements_encoding(
390407
# Forward-propagate the existing encoding, minus keys that should be omitted
391408
keep_keys = XARRAY_ENCODING_KEYS - {"compressors", "shards", "chunks"}
392409

393-
if not keep_scale_offset:
410+
if experimental_scale_offset_codec and not keep_scale_offset:
411+
# Push CF scale-offset into the zarr codec pipeline instead of
412+
# decoding to float. The data stays as packed integers on disk,
413+
# but zarr transparently decodes on read.
414+
scale_factor = var_data.encoding.get("scale_factor")
415+
add_offset = var_data.encoding.get("add_offset")
416+
packed_dtype = var_data.encoding.get("dtype")
417+
418+
if scale_factor is not None and add_offset is not None and packed_dtype is not None:
419+
from eopf_geozarr.codecs.scale_offset import scale_offset_from_cf
420+
421+
so_codec = scale_offset_from_cf(
422+
scale_factor=float(scale_factor), add_offset=float(add_offset)
423+
)
424+
# CastValue refuses to cast NaN to integer without an explicit
425+
# mapping, so we need a packed-dtype sentinel for NaN. Prefer
426+
# the source's existing `_FillValue` (it already encodes the
427+
# "no data" semantic via xarray's CF mask_and_scale loop), and
428+
# fall back to the dtype's lowest representable integer.
429+
packed_np_dtype = np.dtype(packed_dtype)
430+
source_fill = var_data.encoding.get("_FillValue")
431+
if source_fill is not None:
432+
nan_sentinel = int(source_fill)
433+
else:
434+
nan_sentinel = int(np.iinfo(packed_np_dtype).min)
435+
cv_codec = CastValue(
436+
data_type=packed_np_dtype.name,
437+
rounding="nearest-even",
438+
scalar_map={
439+
"encode": [("NaN", nan_sentinel)],
440+
"decode": [(nan_sentinel, "NaN")],
441+
},
442+
)
443+
var_encoding["filters"] = (so_codec, cv_codec)
444+
445+
# Strip CF keys and `filters` from `keep_keys` — the codecs handle
446+
# encoding/decoding now, and we don't want the forward-propagation
447+
# loop below to overwrite our freshly-set filters with whatever was
448+
# on the source variable.
449+
keep_keys = keep_keys - CF_SCALE_OFFSET_KEYS - {"_FillValue", "filters"}
450+
var_encoding["fill_value"] = "NaN"
451+
elif not keep_scale_offset:
394452
# When stripping scale/offset, also strip _FillValue since the original
395453
# _FillValue is in raw integer units and meaningless for decoded float data.
396454
keep_keys = keep_keys - CF_SCALE_OFFSET_KEYS - {"_FillValue"}
@@ -399,7 +457,7 @@ def create_measurements_encoding(
399457
# xarray's zarr backend uses "fill_value" (no underscore) in encoding
400458
# to set the zarr-level fill value, distinct from "_FillValue" which
401459
# controls CF-convention attribute masking.
402-
var_encoding["fill_value"] = float("nan")
460+
var_encoding["fill_value"] = "NaN"
403461

404462
for key in keep_keys:
405463
if key in var_data.encoding:
@@ -805,9 +863,15 @@ def create_original_encoding(dataset: xr.Dataset) -> dict[str, XarrayDataArrayEn
805863
var_data = dataset.data_vars[var_name]
806864
var_encoding: XarrayDataArrayEncoding = {}
807865
var_encoding["compressors"] = (compressor,)
808-
for key in XARRAY_ENCODING_KEYS - {"compressors"}:
866+
for key in XARRAY_ENCODING_KEYS - {"compressors", "fill_value"}:
809867
if key in var_data.encoding:
810868
var_encoding[key] = var_data.encoding[key] # type: ignore[literal-required]
869+
# Set the zarr-level `fill_value` explicitly rather than letting xarray
870+
# decide — different xarray versions infer different defaults from the
871+
# variable's `_FillValue`. See `explicit_fill_value` for the rationale.
872+
fv = utils.explicit_fill_value(var_data)
873+
if fv is not utils.UNSET:
874+
var_encoding["fill_value"] = fv
811875
if len(set(var_data.encoding.keys()) - XARRAY_ENCODING_KEYS) > 0:
812876
log.warning(
813877
"Unknown encoding keys in %s: %s",

src/eopf_geozarr/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class TileMatrixLimitJSON(TypedDict):
1212

1313

1414
class XarrayEncodingJSON(TypedDict):
15+
fill_value: NotRequired[object]
1516
chunks: NotRequired[tuple[int, ...]]
1617
compressors: Any
1718
shards: NotRequired[Any]

0 commit comments

Comments
 (0)