Skip to content

Commit 98d920b

Browse files
corentinmusardlukasbindreiter
authored and committed
Add query jobs (#189)
* Add list jobs * Use QueryJobs * address comments * handle (str, str) case better
1 parent 8f089f7 commit 98d920b

21 files changed

Lines changed: 556 additions & 247 deletions

File tree

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.36.0] - 2025-05-27
11+
12+
## Added
13+
14+
- Added `query` method to `JobClient` to query jobs in a given temporal extent and filter by automation id
15+
1016
## [0.35.0] - 2025-04-29
1117

1218
## Added
@@ -161,7 +167,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
161167
- Released packages: `tilebox-datasets`, `tilebox-workflows`, `tilebox-storage`, `tilebox-grpc`
162168

163169

164-
[Unreleased]: https://github.com/tilebox/tilebox-python/compare/v0.35.0...HEAD
170+
[Unreleased]: https://github.com/tilebox/tilebox-python/compare/v0.36.0...HEAD
171+
[0.36.0]: https://github.com/tilebox/tilebox-python/compare/v0.35.0...v0.36.0
165172
[0.35.0]: https://github.com/tilebox/tilebox-python/compare/v0.34.0...v0.35.0
166173
[0.34.0]: https://github.com/tilebox/tilebox-python/compare/v0.33.1...v0.34.0
167174
[0.33.1]: https://github.com/tilebox/tilebox-python/compare/v0.33.0...v0.33.1

tilebox-datasets/tests/data/datapoint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AnyMessage,
3636
Datapoint,
3737
DatapointInterval,
38+
DatapointIntervalLike,
3839
DatapointPage,
3940
IngestResponse,
4041
QueryResultPage,
@@ -59,6 +60,19 @@ def datapoint_intervals(draw: DrawFn) -> DatapointInterval:
5960
return DatapointInterval(start, end, start_exclusive, end_inclusive)
6061

6162

63+
@composite
def datapoint_intervals_like(draw: DrawFn) -> DatapointIntervalLike:
    """A hypothesis strategy for generating random datapoint-interval-like values.

    A random DatapointInterval is drawn first and then returned in one of three forms:
    the interval object itself, a tuple of its start and end id converted to strings,
    or a tuple of its start and end id as stored on the interval.
    """
    base = draw(datapoint_intervals())
    string_pair = (str(base.start_id), str(base.end_id))
    id_pair = (base.start_id, base.end_id)
    # keep the order of the alternatives stable: interval, string tuple, id tuple
    representation = one_of(just(base), just(string_pair), just(id_pair))
    return draw(representation)
74+
75+
6276
@composite
6377
def example_datapoints(draw: DrawFn, generated_fields: bool = False, missing_fields: bool = False) -> ExampleDatapoint:
6478
"""

tilebox-datasets/tests/data/test_datapoint.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from tests.data.datapoint import (
44
anys,
55
datapoint_intervals,
6+
datapoint_intervals_like,
67
datapoint_pages,
78
datapoints,
89
ingest_datapoints_responses,
@@ -13,6 +14,7 @@
1314
AnyMessage,
1415
Datapoint,
1516
DatapointInterval,
17+
DatapointIntervalLike,
1618
DatapointPage,
1719
IngestResponse,
1820
QueryResultPage,
@@ -25,6 +27,19 @@ def test_datapoint_intervals_to_message_and_back(interval: DatapointInterval) ->
2527
assert DatapointInterval.from_message(interval.to_message()) == interval
2628

2729

30+
@given(datapoint_intervals_like())
def test_parse_datapoint_interval_from_tuple(interval: DatapointIntervalLike) -> None:
    """Parsing any interval-like value yields an interval with the expected ids and bound flags."""
    parsed = DatapointInterval.parse(interval)

    if isinstance(interval, DatapointInterval):
        # an already constructed interval is handed back, keeping its own bound flags
        assert parsed == interval, f"Failed parsing interval from {interval}"
        assert parsed.start_exclusive == interval.start_exclusive
        assert parsed.end_inclusive == interval.end_inclusive
    else:
        start, end = interval
        # the tuple holds either strings or id objects, so compare the textual form of the ids
        assert str(parsed.start_id) == str(start), f"Failed parsing start id from {interval}"
        assert str(parsed.end_id) == str(end), f"Failed parsing end id from {interval}"
        # tuples carry no bound flags, so the defaults of parse apply
        assert not parsed.start_exclusive
        assert parsed.end_inclusive
41+
42+
2843
@given(anys())
2944
def test_anys_to_message_and_back(any_: AnyMessage) -> None:
3045
assert AnyMessage.from_message(any_.to_message()) == any_

tilebox-datasets/tilebox/datasets/aio/dataset.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77
import xarray as xr
88
from tqdm.auto import tqdm
99

10+
from _tilebox.grpc.aio.pagination import Pagination as PaginationProtocol
11+
from _tilebox.grpc.aio.pagination import paginated_request
1012
from _tilebox.grpc.aio.producer_consumer import async_producer_consumer
1113
from _tilebox.grpc.error import ArgumentError, NotFoundError
1214
from tilebox.datasets.aio.pagination import (
13-
paginated_request,
1415
with_progressbar,
1516
with_time_progress_callback,
1617
with_time_progressbar,
1718
)
1819
from tilebox.datasets.data.collection import CollectionInfo
1920
from tilebox.datasets.data.data_access import QueryFilters, SpatialFilter, SpatialFilterLike
20-
from tilebox.datasets.data.datapoint import DatapointInterval, DatapointPage, QueryResultPage
21+
from tilebox.datasets.data.datapoint import DatapointInterval, DatapointIntervalLike, DatapointPage, QueryResultPage
2122
from tilebox.datasets.data.datasets import Dataset
2223
from tilebox.datasets.data.pagination import Pagination
2324
from tilebox.datasets.data.time_interval import TimeInterval, TimeIntervalLike
@@ -242,7 +243,7 @@ async def _find_legacy(self, datapoint_id: str, skip_data: bool = False) -> xr.D
242243

243244
async def _find_interval(
244245
self,
245-
datapoint_id_interval: tuple[str, str] | tuple[UUID, UUID],
246+
datapoint_id_interval: DatapointIntervalLike,
246247
end_inclusive: bool = True,
247248
*,
248249
skip_data: bool = False,
@@ -266,18 +267,13 @@ async def _find_interval(
266267
datapoint_id_interval, end_inclusive, skip_data=skip_data, show_progress=show_progress
267268
)
268269

269-
start_id, end_id = datapoint_id_interval
270-
271270
filters = QueryFilters(
272-
temporal_extent=DatapointInterval(
273-
start_id=as_uuid(start_id),
274-
end_id=as_uuid(end_id),
275-
start_exclusive=False,
276-
end_inclusive=end_inclusive,
277-
)
271+
temporal_extent=DatapointInterval.parse(datapoint_id_interval, end_inclusive=end_inclusive)
278272
)
279273

280-
request = partial(self._dataset._service.query, [self._collection.id], filters, skip_data)
274+
async def request(page: PaginationProtocol) -> QueryResultPage:
275+
query_page = Pagination(page.limit, page.starting_after)
276+
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)
281277

282278
initial_page = Pagination()
283279
pages = paginated_request(request, initial_page)
@@ -288,27 +284,19 @@ async def _find_interval(
288284

289285
async def _find_interval_legacy(
290286
self,
291-
datapoint_id_interval: tuple[str, str] | tuple[UUID, UUID],
287+
datapoint_id_interval: DatapointIntervalLike,
292288
end_inclusive: bool = True,
293289
*,
294290
skip_data: bool = False,
295291
show_progress: bool = False,
296292
) -> xr.Dataset:
297-
start_id, end_id = datapoint_id_interval
293+
datapoint_interval = DatapointInterval.parse(datapoint_id_interval, end_inclusive=end_inclusive)
298294

299-
datapoint_interval = DatapointInterval(
300-
start_id=as_uuid(start_id),
301-
end_id=as_uuid(end_id),
302-
start_exclusive=False,
303-
end_inclusive=end_inclusive,
304-
)
305-
request = partial(
306-
self._dataset._service.get_dataset_for_datapoint_interval,
307-
str(self._collection.id),
308-
datapoint_interval,
309-
skip_data,
310-
False,
311-
)
295+
async def request(page: PaginationProtocol) -> DatapointPage:
296+
query_page = Pagination(page.limit, page.starting_after)
297+
return await self._dataset._service.get_dataset_for_datapoint_interval(
298+
str(self._collection.id), datapoint_interval, skip_data, False, query_page
299+
)
312300

313301
initial_page = Pagination()
314302
pages = paginated_request(request, initial_page)
@@ -427,9 +415,10 @@ async def _iter_pages(
427415
yield page
428416

429417
async def _load_page(
430-
self, filters: QueryFilters, skip_data: bool, page: Pagination | None = None
418+
self, filters: QueryFilters, skip_data: bool, page: PaginationProtocol | None = None
431419
) -> QueryResultPage:
432-
return await self._dataset._service.query([self._collection.id], filters, skip_data, page)
420+
query_page = Pagination(page.limit, page.starting_after) if page else Pagination()
421+
return await self._dataset._service.query([self._collection.id], filters, skip_data, query_page)
433422

434423
async def _load_legacy(
435424
self,
@@ -472,10 +461,11 @@ async def _iter_pages_legacy(
472461
yield page
473462

474463
async def _load_page_legacy(
475-
self, time_interval: TimeInterval, skip_data: bool, skip_meta: bool, page: Pagination | None = None
464+
self, time_interval: TimeInterval, skip_data: bool, skip_meta: bool, page: PaginationProtocol | None = None
476465
) -> DatapointPage:
466+
query_page = Pagination(page.limit, page.starting_after) if page else Pagination()
477467
return await self._dataset._service.get_dataset_for_time_interval(
478-
str(self._collection.id), time_interval, skip_data, skip_meta, page
468+
str(self._collection.id), time_interval, skip_data, skip_meta, query_page
479469
)
480470

481471
async def ingest(

tilebox-datasets/tilebox/datasets/aio/pagination.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from collections.abc import AsyncIterator, Awaitable, Callable
2+
from collections.abc import AsyncIterator
33
from datetime import datetime, timezone
44
from typing import TypeVar
55

@@ -9,41 +9,11 @@
99
TimeInterval,
1010
)
1111
from tilebox.datasets.data.datapoint import DatapointPage, QueryResultPage
12-
from tilebox.datasets.data.pagination import Pagination
1312
from tilebox.datasets.progress import ProgressCallback, TimeIntervalProgressBar
1413

1514
ResultPage = TypeVar("ResultPage", bound=DatapointPage | QueryResultPage)
1615

1716

18-
async def paginated_request(
    paging_request: Callable[[Pagination], Awaitable[ResultPage]],
    initial_page: Pagination | None = None,
) -> AsyncIterator[ResultPage]:
    """Iterate over all pages of a paginated gRPC service endpoint.

    Every response is expected to carry a next_page field. That field is passed to the next request
    as long as its starting_after value is set; once it is None, the iteration ends.

    Args:
        paging_request: A function that takes a page as input and returns the response for that page.
            Often this will be a functools.partial object that wraps a gRPC service endpoint
            and only leaves the page argument remaining
        initial_page: The page to use for the first request. Defaults to an empty Pagination.

    Yields:
        The individual pages of the response, in the order they were requested
    """
    current_page = Pagination() if initial_page is None else initial_page

    while True:
        result = await paging_request(current_page)
        yield result

        current_page = result.next_page
        if current_page.starting_after is None:
            # the endpoint reported no further page, so we are done
            return
45-
46-
4717
async def with_progressbar(
4818
paginated_request: AsyncIterator[ResultPage],
4919
progress_description: str,

tilebox-datasets/tilebox/datasets/data/datapoint.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from datetime import datetime
3-
from typing import Any
3+
from typing import Any, TypeAlias
44
from uuid import UUID
55

66
from tilebox.datasets.data.pagination import Pagination
@@ -9,6 +9,8 @@
99
from tilebox.datasets.datasetsv1 import core_pb2, data_access_pb2, data_ingestion_pb2
1010
from tilebox.datasets.message_pool import get_message_type
1111

12+
DatapointIntervalLike: TypeAlias = "tuple[str, str] | tuple[UUID, UUID] | DatapointInterval"
13+
1214

1315
@dataclass(frozen=True)
1416
class DatapointInterval:
@@ -34,6 +36,47 @@ def to_message(self) -> core_pb2.DatapointInterval:
3436
end_inclusive=self.end_inclusive,
3537
)
3638

39+
@classmethod
40+
def parse(
41+
cls, arg: DatapointIntervalLike, start_exclusive: bool = False, end_inclusive: bool = True
42+
) -> "DatapointInterval":
43+
"""
44+
Convert a variety of input types to a DatapointInterval.
45+
46+
Supported input types:
47+
- DatapointInterval: Return the input as is
48+
- tuple of two UUIDs: Return an DatapointInterval with start and end id set to the given values
49+
- tuple of two strings: Return an DatapointInterval with start and end id set to the UUIDs parsed from the given strings
50+
51+
Args:
52+
arg: The input to convert
53+
start_exclusive: Whether the start id is exclusive
54+
end_inclusive: Whether the end id is inclusive
55+
56+
Returns:
57+
DatapointInterval: The parsed ID interval
58+
"""
59+
60+
match arg:
61+
case DatapointInterval(_, _, _, _):
62+
return arg
63+
case (UUID(), UUID()):
64+
start, end = arg
65+
return DatapointInterval(
66+
start_id=start,
67+
end_id=end,
68+
start_exclusive=start_exclusive,
69+
end_inclusive=end_inclusive,
70+
)
71+
case (str(), str()):
72+
start, end = arg
73+
return DatapointInterval(
74+
start_id=UUID(start),
75+
end_id=UUID(end),
76+
start_exclusive=start_exclusive,
77+
end_inclusive=end_inclusive,
78+
)
79+
3780

3881
@dataclass(frozen=True)
3982
class AnyMessage:

0 commit comments

Comments
 (0)