Skip to content

Commit 941e5e5

Browse files
Allow ingestion of enums (#187)
1 parent 8820c38 commit 941e5e5

16 files changed

Lines changed: 124 additions & 149 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## Added
11+
12+
- Support for ingestion of enum values into datasets
13+
1014
## [0.34.0] - 2025-04-15
1115

1216
## Added

tilebox-datasets/tests/data/datapoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ def example_datapoints(draw: DrawFn, generated_fields: bool = False, missing_fie
7575
maybe_none = none() if missing_fields else one_of()
7676

7777
return ExampleDatapoint(
78-
time=(draw(datetime_messages()) if generated_fields else None),
78+
time=draw(datetime_messages()),
7979
id=(draw(uuid_messages()) if generated_fields else None),
8080
ingestion_time=(draw(datetime_messages()) if generated_fields else None),
81-
# geometry=, # skip for now
81+
geometry=draw(geometry_messages()),
8282
some_string=draw(text(alphabet=string.ascii_letters + string.digits, min_size=1, max_size=10) | maybe_none),
8383
some_int=draw(integers(min_value=1, max_value=100) | maybe_none),
8484
some_double=draw(floats(min_value=1.0, max_value=100.0) | maybe_none),

tilebox-datasets/tests/protobuf_conversion/test_protobuf_xarray.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from hypothesis.strategies import lists
66
from numpy.testing import assert_array_almost_equal, assert_array_equal
77
from pandas import to_datetime
8-
from shapely import MultiPolygon, Polygon
8+
from shapely import MultiPolygon, Polygon, from_wkb
99
from xarray.testing import assert_equal
1010

1111
from tests.data.datapoint import datapoint_pages, datapoints, example_datapoints
1212
from tests.example_dataset.example_dataset_pb2 import ExampleDatapoint
1313
from tilebox.datasets.data.datapoint import Datapoint, DatapointPage
1414
from tilebox.datasets.data.time_interval import timestamp_to_datetime, us_to_datetime
15-
from tilebox.datasets.datasetsv1.well_known_types_pb2 import ProcessingLevel
1615
from tilebox.datasets.protobuf_conversion.protobuf_xarray import (
1716
MessageToXarrayConverter,
1817
TimeseriesToXarrayConverter,
@@ -45,6 +44,7 @@ def test_convert_datapoint(datapoint: ExampleDatapoint) -> None: # noqa: PLR091
4544
assert us_to_datetime(to_datetime(ingestion_time, utc=True).value // 1000) == timestamp_to_datetime(
4645
datapoint.ingestion_time
4746
)
47+
assert dataset.geometry.item() == from_wkb(datapoint.geometry.wkb)
4848

4949
dataset = dataset.isel(time=0) # select the only datapoint in the dataset
5050
assert dataset.some_string.item() == datapoint.some_string
@@ -72,8 +72,7 @@ def test_convert_datapoint(datapoint: ExampleDatapoint) -> None: # noqa: PLR091
7272
)
7373

7474
assert isinstance(dataset.some_geometry.item(), Polygon | MultiPolygon)
75-
expected_level = {v: k for k, v in ProcessingLevel.items()}[datapoint.some_enum].removeprefix("PROCESSING_LEVEL_")
76-
assert dataset.some_enum.item() == expected_level
75+
assert dataset.some_enum.item() == datapoint.some_enum
7776

7877
assert list(dataset.some_repeated_string.to_numpy()) == list(datapoint.some_repeated_string)
7978
assert_array_equal(dataset.some_repeated_int.to_numpy(), datapoint.some_repeated_int)

tilebox-datasets/tests/protobuf_conversion/test_to_protobuf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ def test_xarray_dataset_to_protobuf_messages(messages: list[ExampleDatapoint]) -
2323
assert not converted_message.HasField(field.name)
2424
continue
2525

26-
if field.name == "some_enum": # enum ingestion not implemented yet
27-
continue
28-
2926
assert getattr(expected_message, field.name) == getattr(converted_message, field.name), (
3027
f"Field {field.name} mismatch"
3128
)

tilebox-datasets/tests/test_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def test_find_datapoint() -> None:
110110

111111
if not skip_data:
112112
assert datapoint.granule_name.item() == "S2A_MSIL1C_20220713T002201_N0400_R102_T08XNS_20220713T015332.SAFE"
113-
assert datapoint.processing_level.item() == "L1C"
113+
processing_level = datapoint.processing_level.item()
114+
assert datapoint.processing_level.attrs["names"][processing_level] == "L1C"
114115
assert datapoint.copernicus_id.item() == "65505f82-76dd-5e85-b947-a6c879e07446"
115116
assert isinstance(datapoint.geometry.item(), Polygon)
116117
else:

tilebox-datasets/tilebox/datasets/protobuf_conversion/field_types.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ def to_proto(self, value: Any) -> bool:
7878
return bool(value)
7979

8080

81+
class EnumField(ProtobufFieldType):
    """Field type for protobuf enum fields.

    When loading, enum values are kept as their integer numbers. When converting
    back to protobuf, either an enum value name or an enum number is accepted.
    """

    def __init__(self, name_lookup: dict[int, str]) -> None:
        """Create the enum field type.

        Args:
            name_lookup: Mapping from enum numbers to the corresponding enum value names.
        """
        super().__init__(np.uint8)  # we support up to 256 different enum values for now
        self._values_to_name = name_lookup
        # reverse lookup, used to resolve names back to numbers during ingestion
        self._names_to_value = {name: value for value, name in name_lookup.items()}

    def from_proto(self, value: ProtoFieldValue) -> int:
        """Return the enum number of a protobuf enum value unchanged.

        Args:
            value: The enum value as read from the protobuf message.

        Returns:
            The enum number; it is deliberately not translated to its name.

        Raises:
            TypeError: If the given value is not an int.
        """
        if not isinstance(value, int):
            raise TypeError(f"Expected int message but got {type(value)}")
        return value  # we don't parse the value when loading, to avoid having huge arrays of strings

    def to_proto(self, value: str | int) -> int:
        """Convert an enum name or enum number to the enum number to store in the message.

        Args:
            value: Either the name of an enum value or its number.

        Returns:
            The enum number as a plain Python int.

        Raises:
            KeyError: If a name is given that is not part of the name lookup.
            ValueError: If a number is given that is not part of the name lookup.
        """
        if isinstance(value, (str, np.str_)):
            return self._names_to_value[value]
        number = int(value)
        if number not in self._values_to_name:
            raise ValueError(f"Invalid enum value {value}")  # during ingestion, we can raise an error here
        # return the normalized int rather than the raw input, which may be a numpy scalar
        return number
98+
99+
81100
class TimestampField(ProtobufFieldType):
82101
def __init__(self) -> None:
83102
super().__init__("datetime64[ns]")
@@ -252,10 +271,47 @@ def infer_field_type(field: FieldDescriptor) -> ProtobufFieldType:
252271

253272
return _MESSAGE_NAMES_TO_FIELDS[message_name]
254273

274+
if field.type == FieldDescriptor.TYPE_ENUM:
275+
return EnumField(enum_mapping_from_field_descriptor(field))
276+
255277
if field.type == FieldDescriptor.TYPE_BOOL:
256278
return BoolField() # special handling, since we need to convert numpy bools to python bools
257279

258280
if field.type not in _PROTOBUF_TYPE_TO_NUMPY_TYPE:
259281
raise ValueError(f"Unsupported field type {field.type}")
260282

261283
return ProtobufFieldType(_PROTOBUF_TYPE_TO_NUMPY_TYPE[field.type])
284+
285+
286+
def enum_mapping_from_field_descriptor(field: FieldDescriptor) -> dict[int, str]:
    """Build a lookup from enum numbers to enum value names for an enum field.

    The upper-case form of the enum type name is stripped from the front of each
    value name, e.g. FLIGHT_DIRECTION_ASCENDING of the FlightDirection enum is
    mapped to ASCENDING.

    Args:
        field: The field descriptor to create the mapping for. Must be of type FieldDescriptor.TYPE_ENUM.

    Returns:
        A dict mapping each enum number to its name without the enum type prefix.

    Raises:
        ValueError: If the field is not an enum field.
    """
    if field.type != FieldDescriptor.TYPE_ENUM:
        raise ValueError("Expected field to be of type FieldDescriptor.TYPE_ENUM")

    enum_descriptor = field.enum_type
    prefix = f"{_camel_to_uppercase(enum_descriptor.name)}_"

    names_by_number: dict[int, str] = {}
    # enum_descriptor is not a numpy array, even though ruff thinks it is
    for enum_value in enum_descriptor.values:  # noqa: PD011
        names_by_number[enum_value.number] = str(enum_value.name).removeprefix(prefix)
    return names_by_number
302+
303+
304+
def _camel_to_uppercase(name: str) -> str:
    """Convert a CamelCase name to an UPPER_SNAKE_CASE name.

    An underscore is inserted before a capital letter that follows a lowercase
    letter or a digit, and before the last capital of a run of capitals when it
    starts a new word. Acronyms are therefore kept together instead of being
    split into single letters.

    Args:
        name: The name to convert.

    Returns:
        The converted name.

    Examples:
        >>> _camel_to_uppercase("ProcessingLevel")
        'PROCESSING_LEVEL'
        >>> _camel_to_uppercase("HTTPServer")
        'HTTP_SERVER'
    """
    pieces: list[str] = []
    for index, char in enumerate(name):
        if char.isupper() and index > 0:
            previous = name[index - 1]
            following = name[index + 1 : index + 2]  # empty string at the end of the name
            starts_word_after_lower = previous.islower() or previous.isdigit()
            starts_word_after_acronym = previous.isupper() and following.islower()
            if starts_word_after_lower or starts_word_after_acronym:
                pieces.append("_")
        pieces.append(char)
    # leading underscores are dropped, as before
    return "".join(pieces).lstrip("_").upper()

tilebox-datasets/tilebox/datasets/protobuf_conversion/protobuf_xarray.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
from tilebox.datasets.data.datapoint import Datapoint, DatapointPage
1616
from tilebox.datasets.message_pool import get_message_type
17-
from tilebox.datasets.protobuf_conversion.field_types import ProtobufFieldType, ProtoFieldValue, infer_field_type
17+
from tilebox.datasets.protobuf_conversion.field_types import (
18+
EnumField,
19+
ProtobufFieldType,
20+
ProtoFieldValue,
21+
enum_mapping_from_field_descriptor,
22+
infer_field_type,
23+
)
1824

1925
AnyMessage = TypeVar("AnyMessage", bound=Message)
2026

@@ -375,21 +381,13 @@ def __init__(self, field_name: str, enum_names: dict[int, str]) -> None:
375381
Args:
376382
field_name: The name of enum field in the protobuf message
377383
"""
378-
super().__init__(field_name, ProtobufFieldType(np.uint8)) # we support up to 256 different enum values for now
384+
super().__init__(field_name, EnumField(enum_names))
379385
self._enum_names = enum_names
380386

381387
def finalize(self, dataset: xr.Dataset, count: int, dimension_names: tuple[str, ...]) -> str | None:
382388
field_name = super().finalize(dataset, count, dimension_names)
383389
if field_name is not None:
384-
# convert the numeric enum values to the corresponding string names
385-
int_values = dataset[field_name]
386-
dataset[field_name] = (
387-
int_values.dims,
388-
np.vectorize(lambda i: self._enum_names.get(i, ""))(int_values.values),
389-
)
390-
dataset[field_name].attrs["enum_dict"] = ";".join(
391-
f"{name}: {value}" for value, name in self._enum_names.items()
392-
)
390+
dataset[field_name].attrs["names"] = self._enum_names
393391
return field_name
394392

395393

@@ -418,22 +416,6 @@ def _create_field_converters(message: Message, buffer_size: int) -> dict[str, _F
418416
return converters
419417

420418

421-
def _camel_to_uppercase(name: str) -> str:
422-
"""Convert a camelCase name to an UPPER_CASE name.
423-
424-
Args:
425-
name: The name to convert.
426-
427-
Returns:
428-
The converted name.
429-
430-
Examples:
431-
>>> _camel_to_uppercase("ProcessingLevel")
432-
'PROCESSING_LEVEL'
433-
"""
434-
return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_").upper()
435-
436-
437419
def _create_field_converter(field: FieldDescriptor) -> _FieldConverter:
438420
"""
439421
Create a field converter for the given protobuf field descriptor.
@@ -450,12 +432,7 @@ def _create_field_converter(field: FieldDescriptor) -> _FieldConverter:
450432
if field.label == FieldDescriptor.LABEL_REPEATED:
451433
raise NotImplementedError("Repeated enum fields are not supported")
452434

453-
# remove the enum type prefix from the enum values
454-
enum_type_prefix = _camel_to_uppercase(field.enum_type.name) + "_"
455-
return _EnumFieldConverter(
456-
field.name,
457-
{v.number: str(v.name).removeprefix(enum_type_prefix) for v in field.enum_type.values},
458-
)
435+
return _EnumFieldConverter(field.name, enum_mapping_from_field_descriptor(field))
459436

460437
field_type = infer_field_type(field)
461438
if field.label == FieldDescriptor.LABEL_OPTIONAL: # simple fields (in proto3 every simple field is optional)

tilebox-datasets/tilebox/datasets/protobuf_conversion/to_protobuf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ def to_messages( # noqa: C901, PLR0912
7979
field_lengths[len(values)].append(field_name) # to validate all fields have the same length
8080

8181
descriptor = field_descriptors_by_name[field_name]
82-
if descriptor.type == FieldDescriptor.TYPE_ENUM:
83-
continue # skip enums, not supported for now in ingestion
84-
8582
field_type = infer_field_type(descriptor)
8683
if isinstance(field_type, GeobufField):
8784
continue # legacy geometry type, ingestion is only supported for the new geometry type

tilebox-datasets/tilebox/datasets/sync/timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __init__(self, dataset: TimeseriesDataset, info: CollectionInfo) -> None:
139139

140140
def __repr__(self) -> str:
141141
"""Human readable representation of the collection."""
142-
return repr(self._info)
142+
return repr(self.info())
143143

144144
@property
145145
def name(self) -> str:

tilebox-storage/tests/storage_data.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -34,35 +34,7 @@ def ers_granules(draw: DrawFn, ensure_quicklook: bool = False) -> ASFStorageGran
3434
quicklook = f"{_ASF_URL}/BROWSE/E{platform}/{granule_name}.jpg"
3535
urls = StorageURLs(f"{_ASF_URL}/{level}/E{platform}/{file_name}.zip", quicklook)
3636

37-
return ASFStorageGranule(time, granule_name, level, "ASF", file_size, md5sum, urls)
38-
39-
40-
@composite
41-
def s1_granules(draw: DrawFn) -> ASFStorageGranule:
42-
"""Generate a realistic-looking random Sentinel 1 granule."""
43-
level = "RAW"
44-
platform = draw(one_of(just("A"), just("B")))
45-
acquisition_mode = draw(one_of(*(just(am) for am in ["IW", "EW", *[f"S{i + 1}" for i in range(6)]])))
46-
start = draw(datetimes(min_value=datetime(2014, 6, 1), max_value=datetime(2050, 1, 1), timezones=just(None)))
47-
stop = draw(datetimes(min_value=datetime(2014, 6, 1), max_value=datetime(2050, 1, 1), timezones=just(None)))
48-
orbit = draw(integers(min_value=1, max_value=999_999))
49-
random_1 = draw(text(alphabet=string.hexdigits, min_size=6, max_size=6)).upper()
50-
random_2 = draw(text(alphabet=string.hexdigits, min_size=4, max_size=4)).upper()
51-
file_size = draw(integers(min_value=10_000, max_value=999_999_999))
52-
md5sum = draw(text(alphabet=string.hexdigits, min_size=32, max_size=32))
53-
54-
granule_name = (
55-
f"S1{platform}_{acquisition_mode}_{level}__0SDV_{start:%Y%m%dT%H%M%S}_{stop:%Y%m%dT%H%M%S}_"
56-
f"{orbit:06d}_S{random_1}_{random_2}"
57-
)
58-
59-
urls = StorageURLs(f"{_ASF_URL}/{level}/S{platform}/{granule_name}.zip", None)
60-
return ASFStorageGranule(start, granule_name, level, "ASF", file_size, md5sum, urls)
61-
62-
63-
@composite
64-
def asf_granules(draw: DrawFn) -> ASFStorageGranule:
65-
return draw(one_of(ers_granules(), s1_granules()))
37+
return ASFStorageGranule(time, granule_name, file_size, md5sum, urls)
6638

6739

6840
@composite

0 commit comments

Comments
 (0)