diff --git a/src/geozarr_toolkit/cli.py b/src/geozarr_toolkit/cli.py
index 7a5d5de..312acb0 100644
--- a/src/geozarr_toolkit/cli.py
+++ b/src/geozarr_toolkit/cli.py
@@ -47,7 +47,7 @@ def create_parser() -> argparse.ArgumentParser:
     validate_parser.add_argument(
         "--conventions",
         nargs="+",
-        choices=["spatial", "proj", "multiscales"],
+        choices=["spatial", "proj", "multiscales", "geoemb"],
         help="Conventions to validate (auto-detected if not specified)",
     )
     validate_parser.add_argument(
@@ -233,6 +233,27 @@ def info_command(args: argparse.Namespace) -> int:
             print(f" - {asset}")
         print()
 
+    if "geoemb" in conventions:
+        print("Geoembeddings:")
+        print(f" Type: {attrs.get('geoemb:type')}")
+        print(f" Dimensions: {attrs.get('geoemb:dimensions')}")
+        print(f" Model: {attrs.get('geoemb:model')}")
+        # `or []` also covers an explicit null value; a .get() default would not.
+        source_data = attrs.get("geoemb:source_data") or []
+        print(f" Source data: {len(source_data)} reference(s)")
+        print(f" Data type: {attrs.get('geoemb:data_type')}")
+        if attrs.get("geoemb:gsd"):
+            print(f" GSD: {attrs['geoemb:gsd']}m")
+        if attrs.get("geoemb:spatial_layout"):
+            print(f" Spatial layout: {attrs['geoemb:spatial_layout']}")
+        if attrs.get("geoemb:chip_layout"):
+            cl = attrs["geoemb:chip_layout"]
+            print(f" Chip layout: {cl.get('layout_type')} {cl.get('chip_size')}")
+        if attrs.get("geoemb:quantization"):
+            q = attrs["geoemb:quantization"]
+            print(f" Quantization: {q.get('method')} (from {q.get('original_dtype')})")
+        print()
+
     if args.verbose:
         print("Members:")
         for name, item in group.items():
diff --git a/src/geozarr_toolkit/conventions/__init__.py b/src/geozarr_toolkit/conventions/__init__.py
index a728602..af12dac 100644
--- a/src/geozarr_toolkit/conventions/__init__.py
+++ b/src/geozarr_toolkit/conventions/__init__.py
@@ -34,6 +34,15 @@
 """
 
 from geozarr_toolkit.conventions.common import ZarrConventionMetadata
+from geozarr_toolkit.conventions.geoemb import (
+    GEOEMB_SCHEMA_URL,
+    GEOEMB_SPEC_URL,
+    GEOEMB_UUID,
+    ChipLayout,
+    Geoemb,
+    GeoembConventionMetadata,
+    Quantization,
+)
 from geozarr_toolkit.conventions.multiscales import (
     MULTISCALES_SCHEMA_URL,
     MULTISCALES_SPEC_URL,
@@ -61,6 +70,9 @@
 )
 
 __all__ = [
+    "GEOEMB_SCHEMA_URL",
+    "GEOEMB_SPEC_URL",
+    "GEOEMB_UUID",
     "MULTISCALES_SCHEMA_URL",
     "MULTISCALES_SPEC_URL",
     "MULTISCALES_UUID",
@@ -70,12 +82,16 @@
     "SPATIAL_SCHEMA_URL",
     "SPATIAL_SPEC_URL",
     "SPATIAL_UUID",
+    "ChipLayout",
+    "Geoemb",
+    "GeoembConventionMetadata",
     "GeoProj",
     "Multiscales",
     "MultiscalesAttrs",
     "MultiscalesConventionMetadata",
     "Proj",
     "ProjConventionMetadata",
+    "Quantization",
     "ScaleLevel",
     "Spatial",
     "SpatialConventionMetadata",
diff --git a/src/geozarr_toolkit/conventions/geoemb.py b/src/geozarr_toolkit/conventions/geoemb.py
new file mode 100644
index 0000000..c8d7b40
--- /dev/null
+++ b/src/geozarr_toolkit/conventions/geoemb.py
@@ -0,0 +1,211 @@
+"""
+Models for the Geoembeddings Zarr Convention.
+
+This convention defines metadata for geospatial embedding groups stored in
+Zarr format, including encoder model provenance, source data references,
+and processing parameters.
+
+Specification: https://github.com/geo-embeddings/embeddings-zarr-convention
+"""
+
+from __future__ import annotations
+
+from typing import Annotated, Final, Literal
+
+from pydantic import BaseModel, Field, model_validator
+
+from geozarr_toolkit.conventions.common import ZarrConventionMetadata, is_none
+
+GEOEMB_UUID: Final[Literal["61c12cc5-0e28-4056-999a-480cf3fb7e4c"]] = (
+    "61c12cc5-0e28-4056-999a-480cf3fb7e4c"
+)
+GEOEMB_SCHEMA_URL: Final[str] = (
+    "https://github.com/geo-embeddings/embeddings-zarr-convention/blob/main/schema.json"
+)
+GEOEMB_SPEC_URL: Final[str] = (
+    "https://github.com/geo-embeddings/embeddings-zarr-convention/blob/main/README.md"
+)
+
+
+class GeoembConventionMetadata(ZarrConventionMetadata):
+    """Metadata for the geoemb: convention in zarr_conventions array."""
+
+    uuid: Literal["61c12cc5-0e28-4056-999a-480cf3fb7e4c"] = GEOEMB_UUID
+    name: Literal["geoemb:"] = "geoemb:"
+    schema_url: str = GEOEMB_SCHEMA_URL
+    spec_url: str = GEOEMB_SPEC_URL
+    description: str = (
+        "Geoembeddings convention for geospatial embedding arrays with model provenance"
+    )
+
+
+class ChipLayout(BaseModel):
+    """
+    Chip layout configuration for chip-type embeddings.
+
+    Describes how the source imagery was divided into chips (patches).
+
+    Attributes
+    ----------
+    layout_type : str
+        Type of chip layout. Either "regular_grid" or "irregular".
+    chip_size : list[int]
+        Chip dimensions [height, width] in pixels.
+    stride : list[int] | None
+        Stride between chips [y, x]. Defaults to chip_size if not specified.
+    grid_id : str | None
+        Identifier for a predefined grid system.
+    grid_definition : str | None
+        URL to grid definition document.
+    """
+
+    layout_type: Literal["regular_grid", "irregular"]
+    chip_size: list[int] = Field(min_length=2, max_length=2)
+    stride: list[int] | None = Field(None, exclude_if=is_none)
+    grid_id: str | None = Field(None, exclude_if=is_none)
+    grid_definition: str | None = Field(None, exclude_if=is_none)
+
+    model_config = {"extra": "forbid"}
+
+    @model_validator(mode="after")
+    def validate_chip_size_positive(self) -> ChipLayout:
+        """Validate that chip_size values are positive."""
+        if any(s < 1 for s in self.chip_size):
+            raise ValueError("chip_size values must be positive integers")
+        return self
+
+    @model_validator(mode="after")
+    def validate_stride_length(self) -> ChipLayout:
+        """Validate that stride has exactly 2 elements when provided."""
+        if self.stride is not None and len(self.stride) != 2:
+            raise ValueError("stride must have exactly 2 elements [y, x]")
+        return self
+
+
+class ScaleScalar(BaseModel):
+    """
+    Scalar scale for linear dequantization.
+
+    Dequantize with: value = quantized * scale + offset.
+    """
+
+    type: Literal["scalar"]
+    scale: float
+    offset: float = 0.0
+
+    model_config = {"extra": "forbid"}
+
+
+class ScaleArray(BaseModel):
+    """
+    Per-pixel scale factors stored in a separate Zarr array.
+
+    Dequantize with: value[..., y, x] = quantized[..., y, x] * array[..., y, x].
+    Non-finite values (NaN, +inf) in the scale array indicate no-data pixels.
+    """
+
+    type: Literal["array"]
+    array_name: str
+    nodata: float | str | None = Field(None, exclude_if=is_none)
+
+    model_config = {"extra": "forbid"}
+
+
+Scale = Annotated[ScaleScalar | ScaleArray, Field(discriminator="type")]
+
+
+class Quantization(BaseModel):
+    """
+    Quantization details for compressed embeddings.
+
+    Attributes
+    ----------
+    method : str
+        Quantization method (e.g., "linear", "per_pixel_scale",
+        "product_quantization", "binary").
+    original_dtype : str
+        Original data type before quantization (e.g., "float32").
+    quantized_dtype : str | None
+        Data type after quantization (e.g., "int8").
+    scale : ScaleScalar | ScaleArray | None
+        Scale parameters for dequantization.
+    link : str | None
+        URL to quantization codebook or lookup table.
+    """
+
+    method: str
+    original_dtype: str
+    quantized_dtype: str | None = Field(None, exclude_if=is_none)
+    scale: Scale | None = Field(None, exclude_if=is_none)
+    link: str | None = Field(None, exclude_if=is_none)
+
+    model_config = {"extra": "forbid"}
+
+
+class Geoemb(BaseModel):
+    """
+    Geoembeddings convention attributes for a Zarr group.
+
+    Attributes
+    ----------
+    type : str
+        Type of embedding: "pixel" for per-pixel embeddings,
+        "chip" for image patch embeddings. Required.
+    dimensions : int
+        Dimensionality of the embedding vector. Required.
+    model : str
+        URL reference to the encoder model used to generate embeddings. Required.
+    source_data : list[str]
+        URL references to the source datasets. Required, at least one item.
+    data_type : str
+        Data type of stored embeddings (e.g., "float32", "int8"). Required.
+    gsd : float | None
+        Ground sample distance in meters.
+    chip_layout : ChipLayout | None
+        Chip layout configuration. Required when type is "chip".
+    quantization : Quantization | None
+        Compression/quantization details.
+    spatial_layout : str | None
+        Spatial organization scheme: "utm_zones" or "global".
+    build_version : str | None
+        Version of the software that built this store.
+    benchmark : list[str] | None
+        URLs to benchmark evaluation results.
+    """
+
+    type: Literal["pixel", "chip"] = Field(alias="geoemb:type")
+    dimensions: int = Field(alias="geoemb:dimensions", ge=1)
+    model: str = Field(alias="geoemb:model")
+    source_data: list[str] = Field(alias="geoemb:source_data", min_length=1)
+    data_type: str = Field(alias="geoemb:data_type")
+    gsd: float | None = Field(None, alias="geoemb:gsd", exclude_if=is_none)
+    chip_layout: ChipLayout | None = Field(
+        None, alias="geoemb:chip_layout", exclude_if=is_none
+    )
+    quantization: Quantization | None = Field(
+        None, alias="geoemb:quantization", exclude_if=is_none
+    )
+    spatial_layout: Literal["utm_zones", "global"] | None = Field(
+        None, alias="geoemb:spatial_layout", exclude_if=is_none
+    )
+    build_version: str | None = Field(
+        None, alias="geoemb:build_version", exclude_if=is_none
+    )
+    benchmark: list[str] | None = Field(
+        None, alias="geoemb:benchmark", exclude_if=is_none
+    )
+
+    model_config = {
+        "extra": "allow",
+        "populate_by_name": True,
+        "serialize_by_alias": True,
+    }
+
+    @model_validator(mode="after")
+    def validate_chip_layout_required(self) -> Geoemb:
+        """Validate that chip_layout is provided when type is 'chip'."""
+        if self.type == "chip" and self.chip_layout is None:
+            raise ValueError(
+                "geoemb:chip_layout is required when geoemb:type is 'chip'"
+            )
+        return self
diff --git a/src/geozarr_toolkit/helpers/__init__.py b/src/geozarr_toolkit/helpers/__init__.py
index 61ec63d..ed0273b 100644
--- a/src/geozarr_toolkit/helpers/__init__.py
+++ b/src/geozarr_toolkit/helpers/__init__.py
@@ -19,6 +19,7 @@
 from geozarr_toolkit.helpers.validation import (
     detect_conventions,
     validate_attrs,
+    validate_geoemb,
     validate_group,
     validate_multiscales,
     validate_multiscales_structure,
@@ -37,6 +38,7 @@
     "from_geotransform",
     "from_rioxarray",
     "validate_attrs",
+    "validate_geoemb",
     "validate_group",
     "validate_multiscales",
     "validate_multiscales_structure",
diff --git a/src/geozarr_toolkit/helpers/validation.py b/src/geozarr_toolkit/helpers/validation.py
index 570f1b0..ccb1c46 100644
--- a/src/geozarr_toolkit/helpers/validation.py
+++ b/src/geozarr_toolkit/helpers/validation.py
@@ -12,9 +12,11 @@
 from pydantic import ValidationError
 
 from geozarr_toolkit.conventions import (
+    GEOEMB_UUID,
     MULTISCALES_UUID,
     PROJ_UUID,
     SPATIAL_UUID,
+    Geoemb,
     Multiscales,
     Proj,
     Spatial,
@@ -100,6 +102,40 @@ def validate_multiscales(attrs: dict[str, Any]) -> tuple[bool, list[str]]:
         return False, [str(err) for err in e.errors()]
 
 
+def validate_geoemb(attrs: dict[str, Any]) -> tuple[bool, list[str]]:
+    """
+    Validate attributes against the geoemb: convention.
+
+    Parameters
+    ----------
+    attrs : dict
+        Attributes dictionary to validate.
+
+    Returns
+    -------
+    tuple[bool, list[str]]
+        (is_valid, list_of_errors)
+
+    Example
+    -------
+    ```python
+    is_valid, errors = validate_geoemb({
+        "geoemb:type": "pixel",
+        "geoemb:dimensions": 768,
+        "geoemb:model": "https://huggingface.co/made-with-clay/Clay",
+        "geoemb:source_data": ["https://registry.opendata.aws/sentinel-2-l2a-cogs/"],
+        "geoemb:data_type": "float32",
+    })
+    assert is_valid and errors == []
+    ```
+    """
+    try:
+        Geoemb(**attrs)
+        return True, []
+    except ValidationError as e:
+        return False, [str(err) for err in e.errors()]
+
+
 def validate_zarr_conventions(attrs: dict[str, Any]) -> tuple[bool, list[str]]:
     """
     Validate that zarr_conventions array is properly formatted.
@@ -164,6 +200,10 @@ def detect_conventions(attrs: dict[str, Any]) -> list[str]:
     if "multiscales" in attrs:
         detected.append("multiscales")
 
+    # Check for geoemb convention
+    if "geoemb:type" in attrs:
+        detected.append("geoemb")
+
     # Also check zarr_conventions array
     if "zarr_conventions" in attrs:
         for conv in attrs["zarr_conventions"]:
@@ -178,6 +218,10 @@ def detect_conventions(attrs: dict[str, Any]) -> list[str]:
                     uuid == MULTISCALES_UUID or name == "multiscales"
                 ) and "multiscales" not in detected:
                     detected.append("multiscales")
+                if (
+                    uuid == GEOEMB_UUID or name == "geoemb:"
+                ) and "geoemb" not in detected:
+                    detected.append("geoemb")
 
     return detected
 
@@ -234,6 +278,10 @@ def validate_group(
         is_valid, errors = validate_multiscales(attrs)
         results["multiscales"] = errors
 
+    if "geoemb" in conventions:
+        is_valid, errors = validate_geoemb(attrs)
+        results["geoemb"] = errors
+
     # Always check zarr_conventions if present
     if "zarr_conventions" in attrs:
         is_valid, errors = validate_zarr_conventions(attrs)
@@ -323,6 +371,10 @@ def validate_attrs(
         _, errors = validate_multiscales(attrs)
         results["multiscales"] = errors
 
+    if "geoemb" in conventions:
+        _, errors = validate_geoemb(attrs)
+        results["geoemb"] = errors
+
     if "zarr_conventions" in attrs:
         _, errors = validate_zarr_conventions(attrs)
         results["zarr_conventions"] = errors
diff --git a/tests/test_conventions/test_geoemb.py b/tests/test_conventions/test_geoemb.py
new file mode 100644
index 0000000..9728633
--- /dev/null
+++ b/tests/test_conventions/test_geoemb.py
@@ -0,0 +1,124 @@
+"""Tests for the geoemb convention model."""
+
+from __future__ import annotations
+
+import pytest
+from pydantic import ValidationError
+
+from geozarr_toolkit.conventions import (
+    GEOEMB_UUID,
+    ChipLayout,
+    Geoemb,
+    GeoembConventionMetadata,
+    Quantization,
+)
+
+MINIMAL_PIXEL: dict = {
+    "geoemb:type": "pixel",
+    "geoemb:dimensions": 64,
+    "geoemb:model": "https://arxiv.org/abs/2507.22291",
+    "geoemb:source_data": [
+        "https://developers.google.com/earth-engine/datasets/catalog/GOOGLE_SATELLITE_EMBEDDING_V1_ANNUAL"
+    ],
+    "geoemb:data_type": "int8",
+}
+
+MINIMAL_CHIP_LAYOUT: dict = {"layout_type": "regular_grid", "chip_size": [256, 256]}
+
+
+class TestGeoembConventionMetadata:
+    def test_defaults(self) -> None:
+        meta = GeoembConventionMetadata()
+        assert meta.uuid == GEOEMB_UUID
+        assert meta.name == "geoemb:"
+        assert "geo-embeddings" in meta.schema_url
+        assert "geo-embeddings" in meta.spec_url
+
+    def test_serialization(self) -> None:
+        data = GeoembConventionMetadata().model_dump(exclude_none=True)
+        assert data["uuid"] == GEOEMB_UUID
+
+
+class TestGeoemb:
+    def test_minimal_pixel(self) -> None:
+        emb = Geoemb(**MINIMAL_PIXEL)
+        assert emb.type == "pixel"
+        assert emb.dimensions == 64
+        assert emb.chip_layout is None
+
+    def test_minimal_chip(self) -> None:
+        emb = Geoemb(
+            **{
+                **MINIMAL_PIXEL,
+                "geoemb:type": "chip",
+                "geoemb:chip_layout": MINIMAL_CHIP_LAYOUT,
+            }
+        )
+        assert emb.type == "chip"
+        assert emb.chip_layout.layout_type == "regular_grid"
+        assert emb.chip_layout.chip_size == [256, 256]
+
+    def test_chip_requires_chip_layout(self) -> None:
+        with pytest.raises(ValidationError, match="geoemb:chip_layout is required"):
+            Geoemb(**{**MINIMAL_PIXEL, "geoemb:type": "chip"})
+
+    def test_optional_fields(self) -> None:
+        emb = Geoemb(
+            **{
+                **MINIMAL_PIXEL,
+                "geoemb:gsd": 10.0,
+                "geoemb:spatial_layout": "utm_zones",
+                "geoemb:build_version": "1.0.0",
+            }
+        )
+        assert emb.gsd == 10.0
+        assert emb.spatial_layout == "utm_zones"
+
+    def test_quantization_scalar_scale(self) -> None:
+        emb = Geoemb(
+            **{
+                **MINIMAL_PIXEL,
+                "geoemb:quantization": {
+                    "method": "linear",
+                    "original_dtype": "float32",
+                    "scale": {"type": "scalar", "scale": 0.01},
+                },
+            }
+        )
+        assert isinstance(emb.quantization, Quantization)
+        assert emb.quantization.method == "linear"
+        assert emb.quantization.scale.type == "scalar"
+
+    def test_quantization_array_scale(self) -> None:
+        emb = Geoemb(
+            **{
+                **MINIMAL_PIXEL,
+                "geoemb:quantization": {
+                    "method": "per_pixel_scale",
+                    "original_dtype": "float32",
+                    "scale": {"type": "array", "array_name": "scales"},
+                },
+            }
+        )
+        assert emb.quantization.scale.type == "array"
+        assert emb.quantization.scale.array_name == "scales"
+
+    def test_serialization_uses_aliases(self) -> None:
+        emb = Geoemb(**MINIMAL_PIXEL)
+        data = emb.model_dump(by_alias=True)
+        assert "geoemb:type" in data
+        assert "geoemb:dimensions" in data
+
+    def test_extra_fields_allowed(self) -> None:
+        emb = Geoemb(**{**MINIMAL_PIXEL, "zarr_conventions": []})
+        assert emb.model_extra == {"zarr_conventions": []}
+
+
+class TestChipLayout:
+    def test_invalid_chip_size_values(self) -> None:
+        with pytest.raises(ValidationError, match="positive"):
+            ChipLayout(layout_type="regular_grid", chip_size=[256, 0])
+
+    def test_stride_wrong_length(self) -> None:
+        with pytest.raises(ValidationError, match="2 elements"):
+            ChipLayout(layout_type="regular_grid", chip_size=[256, 256], stride=[256])
diff --git a/tests/test_helpers/test_validation.py b/tests/test_helpers/test_validation.py
index 33f6723..72253b2 100644
--- a/tests/test_helpers/test_validation.py
+++ b/tests/test_helpers/test_validation.py
@@ -2,10 +2,11 @@
 
 from __future__ import annotations
 
-from geozarr_toolkit.conventions import MULTISCALES_UUID, SPATIAL_UUID
+from geozarr_toolkit.conventions import GEOEMB_UUID, MULTISCALES_UUID, SPATIAL_UUID
 from geozarr_toolkit.helpers.validation import (
     detect_conventions,
     validate_attrs,
+    validate_geoemb,
     validate_multiscales,
     validate_proj,
     validate_spatial,
@@ -71,6 +72,31 @@ def test_missing_multiscales_key(self) -> None:
         assert "Missing 'multiscales'" in errors[0]
 
 
+class TestValidateGeoemb:
+    """Tests for validate_geoemb."""
+
+    VALID_ATTRS: dict = {
+        "geoemb:type": "pixel",
+        "geoemb:dimensions": 64,
+        "geoemb:model": "https://arxiv.org/abs/2507.22291",
+        "geoemb:source_data": [
+            "https://developers.google.com/earth-engine/datasets/catalog/GOOGLE_SATELLITE_EMBEDDING_V1_ANNUAL"
+        ],
+        "geoemb:data_type": "int8",
+    }
+
+    def test_valid(self) -> None:
+        is_valid, errors = validate_geoemb(self.VALID_ATTRS)
+        assert is_valid
+        assert errors == []
+
+    def test_missing_required_field(self) -> None:
+        attrs = {k: v for k, v in self.VALID_ATTRS.items() if k != "geoemb:model"}
+        is_valid, errors = validate_geoemb(attrs)
+        assert not is_valid
+        assert errors
+
+
 class TestValidateZarrConventions:
     """Tests for validate_zarr_conventions."""
 
@@ -127,6 +153,14 @@ def test_detect_from_zarr_conventions(self) -> None:
         detected = detect_conventions(attrs)
         assert "multiscales" in detected
 
+    def test_detect_geoemb_by_prefix(self) -> None:
+        attrs = {"geoemb:type": "pixel"}
+        assert "geoemb" in detect_conventions(attrs)
+
+    def test_detect_geoemb_by_uuid(self) -> None:
+        attrs = {"zarr_conventions": [{"uuid": GEOEMB_UUID}]}
+        assert "geoemb" in detect_conventions(attrs)
+
     def test_detect_all(self) -> None:
         """Test detecting all conventions."""
         attrs = {