Skip to content

Commit c47512b

Browse files
authored
python(feat): add export api for sift client (#490)
1 parent ce4b303 commit c47512b

17 files changed

Lines changed: 1156 additions & 4 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, cast
4+
5+
from sift.calculated_channels.v2.calculated_channels_pb2 import (
6+
CalculatedChannelAbstractChannelReference,
7+
)
8+
from sift.exports.v1.exports_pb2 import (
9+
AssetsAndTimeRange,
10+
CalculatedChannelConfig,
11+
ExportDataRequest,
12+
ExportDataResponse,
13+
ExportOptions,
14+
GetDownloadUrlRequest,
15+
GetDownloadUrlResponse,
16+
RunsAndTimeRange,
17+
TimeRange,
18+
)
19+
from sift.exports.v1.exports_pb2_grpc import ExportServiceStub
20+
21+
from sift_client._internal.low_level_wrappers.base import LowLevelClientBase
22+
from sift_client._internal.util.timestamp import to_pb_timestamp
23+
from sift_client.sift_types.calculated_channel import CalculatedChannel, CalculatedChannelCreate
24+
from sift_client.transport import WithGrpcClient
25+
26+
if TYPE_CHECKING:
27+
from datetime import datetime
28+
29+
from sift_client.sift_types.export import ExportOutputFormat
30+
from sift_client.transport.grpc_transport import GrpcClient
31+
32+
33+
def _build_calc_channel_configs(
34+
calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None,
35+
) -> list[CalculatedChannelConfig]:
36+
"""Convert high-level calculated channel objects to proto CalculatedChannelConfig messages."""
37+
if not calculated_channels:
38+
return []
39+
configs = []
40+
for cc in calculated_channels:
41+
if isinstance(cc, CalculatedChannelCreate):
42+
refs = cc.expression_channel_references or []
43+
else:
44+
refs = cc.channel_references
45+
configs.append(
46+
CalculatedChannelConfig(
47+
name=cc.name,
48+
expression=cc.expression or "",
49+
channel_references=[
50+
CalculatedChannelAbstractChannelReference(
51+
channel_reference=ref.channel_reference,
52+
channel_identifier=ref.channel_identifier,
53+
)
54+
for ref in refs
55+
],
56+
units=cc.units,
57+
)
58+
)
59+
return configs
60+
61+
62+
class ExportsLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """Low-level client for the DataExportAPI.

    This class provides a thin wrapper around the autogenerated gRPC bindings for the DataExportAPI.
    """

    def __init__(self, grpc_client: GrpcClient):
        """Initialize the ExportsLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client)

    @staticmethod
    def _apply_time_bounds(
        selection: RunsAndTimeRange | AssetsAndTimeRange | TimeRange,
        start_time: datetime | None,
        stop_time: datetime | None,
    ) -> None:
        """Copy the provided bounds into ``selection``; bounds that are ``None`` stay unset."""
        if start_time is not None:
            selection.start_time.CopyFrom(to_pb_timestamp(start_time))
        if stop_time is not None:
            selection.stop_time.CopyFrom(to_pb_timestamp(stop_time))

    async def export_data(
        self,
        *,
        output_format: ExportOutputFormat,
        run_ids: list[str] | None = None,
        asset_ids: list[str] | None = None,
        start_time: datetime | None = None,
        stop_time: datetime | None = None,
        channel_ids: list[str] | None = None,
        calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None = None,
        simplify_channel_names: bool = False,
        combine_runs: bool = False,
        split_export_by_asset: bool = False,
        split_export_by_run: bool = False,
    ) -> str:
        """Initiate a data export.

        Builds the ExportDataRequest proto and makes the gRPC call.
        Sets whichever time_selection oneof fields are provided
        (run_ids, asset_ids, or time range); the server validates
        the request.

        Args:
            output_format: Export format; its ``.value`` is sent in the request.
            run_ids: When not ``None``, a ``RunsAndTimeRange`` selection is set.
            asset_ids: When not ``None``, an ``AssetsAndTimeRange`` selection is set.
            start_time: Optional lower bound copied into each selection that is set.
            stop_time: Optional upper bound copied into each selection that is set.
            channel_ids: Channel IDs to export; ``None`` is sent as an empty list.
            calculated_channels: Calculated channels, converted to
                ``CalculatedChannelConfig`` messages.
            simplify_channel_names: Forwarded to ``ExportOptions``.
            combine_runs: Forwarded to ``ExportOptions``.
            split_export_by_asset: Forwarded to ``ExportOptions``.
            split_export_by_run: Forwarded to ``ExportOptions``.

        Returns:
            The job ID for the background export.
        """
        request = ExportDataRequest(
            output_format=output_format.value,
            export_options=ExportOptions(
                use_legacy_format=False,
                simplify_channel_names=simplify_channel_names,
                combine_runs=combine_runs,
                split_export_by_asset=split_export_by_asset,
                split_export_by_run=split_export_by_run,
            ),
            channel_ids=channel_ids or [],
            calculated_channel_configs=_build_calc_channel_configs(calculated_channels),
        )

        # NOTE(review): if these selections really share a proto oneof, passing both
        # run_ids and asset_ids leaves only the asset selection set, because it is
        # assigned last -- confirm callers never supply both.
        if run_ids is not None:
            runs_selection = RunsAndTimeRange(run_ids=run_ids)
            self._apply_time_bounds(runs_selection, start_time, stop_time)
            request.runs_and_time_range.CopyFrom(runs_selection)

        if asset_ids is not None:
            assets_selection = AssetsAndTimeRange(asset_ids=asset_ids)
            self._apply_time_bounds(assets_selection, start_time, stop_time)
            request.assets_and_time_range.CopyFrom(assets_selection)

        if run_ids is None and asset_ids is None:
            # Neither runs nor assets given: select by time range alone. The
            # time_range field is set even when both bounds are None.
            time_selection = TimeRange()
            self._apply_time_bounds(time_selection, start_time, stop_time)
            request.time_range.CopyFrom(time_selection)

        response = await self._grpc_client.get_stub(ExportServiceStub).ExportData(request)
        response = cast("ExportDataResponse", response)
        return response.job_id

    async def get_download_url(self, job_id: str) -> str:
        """Get the download URL for a background export job.

        Args:
            job_id: The job ID returned from export_data.

        Returns:
            The presigned URL to download the exported zip file.
        """
        request = GetDownloadUrlRequest(job_id=job_id)
        response = await self._grpc_client.get_stub(ExportServiceStub).GetDownloadUrl(request)
        response = cast("GetDownloadUrlResponse", response)
        return response.presigned_url
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from sift_client.sift_types.calculated_channel import CalculatedChannel, CalculatedChannelCreate
6+
from sift_client.sift_types.channel import ChannelReference
7+
8+
if TYPE_CHECKING:
9+
from sift_client.resources.channels import ChannelsAPIAsync
10+
11+
12+
async def resolve_calculated_channels(
    calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None,
    channels_api: ChannelsAPIAsync,
) -> list[CalculatedChannel | CalculatedChannelCreate] | None:
    """Swap channel-name identifiers in channel references for channel IDs.

    Every reference's identifier is looked up as a channel name through
    ``channels_api.find``. A hit replaces the identifier with the found
    channel's ID; a miss leaves the reference untouched, on the assumption
    that the identifier is already an ID.

    Args:
        calculated_channels: Calculated channels whose references should be
            resolved. ``None`` or an empty list yields ``None``.
        channels_api: Channels API used for the name lookups.

    Returns:
        ``None`` for empty input; otherwise a new list holding one
        ``CalculatedChannelCreate`` per input channel, in input order.
    """
    if not calculated_channels:
        return None

    async def _resolve_reference(reference: ChannelReference, asset_ids) -> ChannelReference:
        match = await channels_api.find(
            name=reference.channel_identifier,
            assets=asset_ids,
        )
        if match is None:
            return reference
        return ChannelReference(
            channel_reference=reference.channel_reference,
            channel_identifier=match._id_or_error,
        )

    output: list[CalculatedChannel | CalculatedChannelCreate] = []
    for calc in calculated_channels:
        if isinstance(calc, CalculatedChannelCreate):
            references = calc.expression_channel_references or []
        else:
            references = calc.channel_references

        # Lookups are awaited one at a time, in reference order.
        looked_up: list[ChannelReference] = [
            await _resolve_reference(reference, calc.asset_ids) for reference in references
        ]

        output.append(
            CalculatedChannelCreate(
                name=calc.name,
                expression=calc.expression,
                expression_channel_references=looked_up,
                units=calc.units or None,
            )
        )
    return output
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import Any, Callable
5+
6+
7+
async def run_sync_function(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run a synchronous function in the loop's default executor to avoid blocking the event loop.

    Args:
        fn: The synchronous callable to run.
        *args: Positional arguments passed to ``fn``.
        **kwargs: Keyword arguments passed to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    loop = asyncio.get_running_loop()
    if kwargs:
        # run_in_executor forwards positional arguments only, so keyword
        # arguments are bound in a closure.
        return await loop.run_in_executor(None, lambda: fn(*args, **kwargs))
    return await loop.run_in_executor(None, fn, *args)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
import zipfile
5+
from typing import TYPE_CHECKING
6+
7+
from sift_client.errors import SiftWarning
8+
9+
if TYPE_CHECKING:
10+
from pathlib import Path
11+
12+
from sift_client.transport.rest_transport import RestClient
13+
14+
15+
def download_file(
    signed_url: str,
    output_path: Path,
    *,
    rest_client: RestClient,
    chunk_size: int = 4 * 1024 * 1024,
) -> Path:
    """Download a file from a presigned URL in streaming chunks.

    Args:
        signed_url: The URL to download from.
        output_path: Path where the file will be saved. Parent directories are created if needed.
        rest_client: The SDK rest client to use for the download.
        chunk_size: Number of bytes requested per streamed chunk. Defaults to 4 MiB.

    Returns:
        The path to the downloaded file (``output_path``).

    Raises:
        ValueError: If ``chunk_size`` is not a positive number.
        requests.HTTPError: If the download request fails.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Strip the session's default Authorization header, presigned URLs carry their own auth
    with rest_client.get(signed_url, stream=True, headers={"Authorization": None}) as response:
        response.raise_for_status()
        with output_path.open("wb") as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                # Skip empty chunks.
                if chunk:
                    file.write(chunk)
    return output_path
38+
39+
40+
def extract_zip(zip_path: Path, output_dir: Path, *, delete_zip: bool = True) -> list[Path]:
    """Extract a zip file to a directory.

    The returned paths are the ones ``zipfile`` actually wrote to. ``zipfile``
    sanitizes member names while extracting (for example it drops ``..``
    components and absolute prefixes), so rebuilding paths from the raw archive
    names could point at files that were never created.

    Args:
        zip_path: Path to the zip file.
        output_dir: Directory to extract contents into. Created if it doesn't exist.
        delete_zip: If True (default), delete the zip file after extraction.

    Returns:
        List of paths to the extracted files (excludes directories), in archive order.

    Raises:
        zipfile.BadZipFile: If the file is not a valid zip.
    """
    # Imported here because this module only imports Path for type checking.
    from pathlib import Path

    output_dir.mkdir(parents=True, exist_ok=True)
    extracted_files: list[Path] = []
    with zipfile.ZipFile(zip_path, "r") as zip_file:
        for member in zip_file.infolist():
            # Directory entries are still extracted so empty folders get created.
            written_to = zip_file.extract(member, output_dir)
            if not member.is_dir():
                extracted_files.append(Path(written_to))
    if delete_zip:
        try:
            zip_path.unlink()
        except OSError:
            # Cleanup is best-effort: the extraction itself already succeeded.
            warnings.warn(f"Failed to delete zip file '{zip_path}'", SiftWarning, stacklevel=2)
    return extracted_files
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import AsyncMock, MagicMock
4+
5+
import pytest
6+
7+
from sift_client._internal.util.channels import resolve_calculated_channels
8+
from sift_client.sift_types.calculated_channel import (
9+
CalculatedChannel,
10+
CalculatedChannelCreate,
11+
ChannelReference,
12+
)
13+
from sift_client.sift_types.channel import Channel
14+
15+
16+
class TestResolveCalculatedChannels:
    """Tests for ``resolve_calculated_channels`` against a mocked channels API."""

    @pytest.mark.asyncio
    async def test_none_passthrough(self):
        """``None`` input comes back as ``None``."""
        channels_api = MagicMock()
        channels_api.find = AsyncMock(return_value=None)

        outcome = await resolve_calculated_channels(None, channels_api=channels_api)

        assert outcome is None

    @pytest.mark.asyncio
    async def test_resolves_name_to_uuid(self):
        """A reference whose name lookup succeeds gets the found channel's ID."""
        found_channel = MagicMock(spec=Channel)
        found_channel._id_or_error = "resolved-uuid"
        channels_api = MagicMock()
        channels_api.find = AsyncMock(return_value=found_channel)

        calc_channel = MagicMock(spec=CalculatedChannel)
        calc_channel.name = "calc"
        calc_channel.expression = "$1 + 10"
        calc_channel.units = "m/s"
        calc_channel.asset_ids = ["asset-1"]
        calc_channel.channel_references = [
            ChannelReference(channel_reference="$1", channel_identifier="sensor.vel")
        ]

        result = await resolve_calculated_channels([calc_channel], channels_api=channels_api)

        assert result is not None
        assert len(result) == 1
        resolved_refs = result[0].expression_channel_references
        assert resolved_refs is not None
        assert resolved_refs[0].channel_identifier == "resolved-uuid"

    @pytest.mark.asyncio
    async def test_keeps_identifier_when_not_found(self):
        """A reference whose name lookup misses keeps its original identifier."""
        channels_api = MagicMock()
        channels_api.find = AsyncMock(return_value=None)
        original = CalculatedChannelCreate(
            name="x",
            expression="$1",
            units="m",
            expression_channel_references=[
                ChannelReference(channel_reference="$1", channel_identifier="ch-1")
            ],
        )

        result = await resolve_calculated_channels([original], channels_api=channels_api)

        assert result is not None
        assert result[0] == original

0 commit comments

Comments
 (0)