diff --git a/.github/workflows/python_ci.yaml b/.github/workflows/python_ci.yaml index dee4a54f8..4dbac8c5b 100644 --- a/.github/workflows/python_ci.yaml +++ b/.github/workflows/python_ci.yaml @@ -27,7 +27,7 @@ jobs: - name: Pip install run: | python -m pip install --upgrade pip - pip install '.[development,openssl,tdms,rosbags,hdf5]' + pip install '.[development,openssl,tdms,rosbags,hdf5,sift-stream]' - name: Lint run: | ruff check diff --git a/python/lib/sift_client/__init__.py b/python/lib/sift_client/__init__.py index edbe5eaba..2ced0916c 100644 --- a/python/lib/sift_client/__init__.py +++ b/python/lib/sift_client/__init__.py @@ -203,6 +203,9 @@ async def get_asset_async(): """ +import logging +import sys + from sift_client.client import SiftClient from sift_client.transport import SiftConnectionConfig @@ -210,3 +213,12 @@ async def get_asset_async(): "SiftClient", "SiftConnectionConfig", ] + +logger = logging.getLogger("sift_client") +logging.basicConfig( + level=logging.ERROR, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + + +handler = logging.StreamHandler(sys.stdout) +logger.addHandler(handler) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/__init__.py b/python/lib/sift_client/_internal/low_level_wrappers/__init__.py index b83c28d11..6bdef7f17 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/__init__.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/__init__.py @@ -2,12 +2,18 @@ from sift_client._internal.low_level_wrappers.calculated_channels import ( CalculatedChannelsLowLevelClient, ) +from sift_client._internal.low_level_wrappers.channels import ChannelsLowLevelClient +from sift_client._internal.low_level_wrappers.ingestion import IngestionLowLevelClient from sift_client._internal.low_level_wrappers.ping import PingLowLevelClient +from sift_client._internal.low_level_wrappers.rules import RulesLowLevelClient from sift_client._internal.low_level_wrappers.runs import RunsLowLevelClient 
__all__ = [ "AssetsLowLevelClient", "CalculatedChannelsLowLevelClient", + "ChannelsLowLevelClient", + "IngestionLowLevelClient", "PingLowLevelClient", + "RulesLowLevelClient", "RunsLowLevelClient", ] diff --git a/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py b/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py index a16c01bf1..6ca764b37 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/calculated_channels.py @@ -23,7 +23,7 @@ from sift.calculated_channels.v2.calculated_channels_pb2_grpc import CalculatedChannelServiceStub from sift_client._internal.low_level_wrappers.base import LowLevelClientBase -from sift_client.transport.grpc_transport import GrpcClient +from sift_client.transport import GrpcClient, WithGrpcClient from sift_client.types.calculated_channel import ( CalculatedChannel, CalculatedChannelUpdate, @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -class CalculatedChannelsLowLevelClient(LowLevelClientBase): +class CalculatedChannelsLowLevelClient(LowLevelClientBase, WithGrpcClient): """ Low-level client for the CalculatedChannelsAPI. @@ -47,7 +47,7 @@ def __init__(self, grpc_client: GrpcClient): Args: grpc_client: The gRPC client to use for making API calls. 
# Configure logging
logger = logging.getLogger(__name__)

# Page size the channels listing uses when none is requested; used below to decide
# whether a caller-supplied max_results should shrink the requested page size.
CHANNELS_DEFAULT_PAGE_SIZE = 10_000


class ChannelsLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """
    Low-level client for the ChannelsAPI.

    This class provides a thin wrapper around the autogenerated bindings for the ChannelsAPI.
    """

    def __init__(self, grpc_client: GrpcClient):
        """
        Initialize the ChannelsLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client)

    async def get_channel(self, channel_id: str) -> Channel:
        """
        Get a channel by channel_id.

        Args:
            channel_id: The channel ID to get.

        Returns:
            The Channel.

        Raises:
            ValueError: If channel_id is not provided.
        """
        # Fail fast with the documented error instead of sending an empty ID to the server.
        if not channel_id:
            raise ValueError("channel_id must be provided")

        request = GetChannelRequest(channel_id=channel_id)
        response = await self._grpc_client.get_stub(ChannelServiceStub).GetChannel(request)
        grpc_channel = cast(GetChannelResponse, response).channel
        return Channel._from_proto(grpc_channel)

    async def list_channels(
        self,
        *,
        page_size: int | None = None,
        page_token: str | None = None,
        query_filter: str | None = None,
        order_by: str | None = None,
    ) -> tuple[list[Channel], str]:
        """
        List a single page of channels with optional filtering and pagination.

        Args:
            page_size: The maximum number of channels to return.
            page_token: A page token for pagination.
            query_filter: A CEL filter string.
            order_by: How to order the retrieved channels.

        Returns:
            A tuple of (channels, next_page_token).
        """
        # Only forward the options that were actually supplied; falsy values
        # (None, "", 0) are left unset on the request.
        request_kwargs: dict[str, Any] = {}
        if query_filter:
            request_kwargs["filter"] = query_filter
        if order_by:
            request_kwargs["order_by"] = order_by
        if page_size:
            request_kwargs["page_size"] = page_size
        if page_token:
            request_kwargs["page_token"] = page_token

        request = ListChannelsRequest(**request_kwargs)
        response = await self._grpc_client.get_stub(ChannelServiceStub).ListChannels(request)
        response = cast(ListChannelsResponse, response)

        channels = [Channel._from_proto(channel) for channel in response.channels]
        return channels, response.next_page_token

    async def list_all_channels(
        self,
        *,
        query_filter: str | None = None,
        order_by: str | None = None,
        max_results: int | None = None,
    ) -> list[Channel]:
        """
        List all channels with optional filtering, following pagination.

        Args:
            query_filter: A CEL filter string.
            order_by: How to order the retrieved channels.
            max_results: Maximum number of results to return.

        Returns:
            A list of all matching channels.
        """
        # Channels default page size is 10,000 so lower it if we're passing max_results
        page_size = None
        if max_results is not None and max_results <= CHANNELS_DEFAULT_PAGE_SIZE:
            page_size = max_results
        return await self._handle_pagination(
            self.list_channels,
            kwargs={"query_filter": query_filter},
            page_size=page_size,
            order_by=order_by,
            max_results=max_results,
        )
# Number of channels queried per GetData request. Kept at 1 because of the
# server-side pagination issue with multi-channel requests; each channel gets
# its own request until that is resolved.
REQUEST_BATCH_SIZE = 1


class ChannelCacheEntry(BaseModel):
    """Cached data for one channel together with the time window the cache claims to cover."""

    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Data frame indexed by timestamp; one column per (bit-field qualified) channel name.
    data: pd.DataFrame
    # Inclusive bounds of the interval this entry is considered to cover.
    start_time: datetime
    end_time: datetime


class ChannelCache(BaseModel):
    """In-memory cache of channel data keyed by channel ID."""

    # Maps channel name (and "<channel>.<bit field element>" names) to channel ID.
    name_id_map: dict[str, str]
    # Maps channel ID to its cached data.
    channels: dict[str, ChannelCacheEntry]


class DataLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """
    Low-level client for fetching channel data.

    This class provides a thin wrapper around the autogenerated bindings for the DataAPI.
    """

    # NOTE(review): this is a class attribute, so every DataLowLevelClient instance in the
    # process shares one cache -- confirm that is intended (e.g. for clients pointed at
    # different servers or organizations).
    channel_cache: ChannelCache = ChannelCache(name_id_map={}, channels={})

    def __init__(self, grpc_client: GrpcClient):
        """
        Initialize the DataLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client)

    def _update_name_id_map(self, channels: list[Channel]):
        """
        Update the name -> id map with the given channels.

        Bit field channels are registered both under the channel name and under
        "<channel name>.<element name>" for each element, all pointing at the channel ID.
        """
        for channel in channels:
            if channel.bit_field_elements:
                for bit_field_element in channel.bit_field_elements:
                    self.channel_cache.name_id_map[channel.name + "." + bit_field_element.name] = (
                        str(channel.id_)
                    )
            self.channel_cache.name_id_map[channel.name] = str(channel.id_)

    # TODO: Cache calls. Only read cache if end_time is more than 30 min in the past.
    # Also, consider manually caching full channel data and evaluating start/end times while ignoring pagination. Do this full caching at a higher level though to handle case where pagination fails.
    async def _get_data_impl(
        self,
        *,
        channel_ids: list[str],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime,
        page_size: int | None = None,
        page_token: str | None = None,
        order_by: str | None = None,
    ) -> Tuple[List[Any], str | None]:
        """
        Fetch one page of raw data for the given channels.

        Args:
            channel_ids: IDs of the channels to query (one query per channel).
            run_id: Optional run to scope the queries to.
            start_time: Start of the requested time range.
            end_time: End of the requested time range.
            page_size: Maximum page size to request.
            page_token: Token of the page to fetch.
            order_by: Accepted for pagination-helper compatibility; not sent to the server.

        Returns:
            A tuple of (data messages of the page, next_page_token).
        """
        queries = [
            Query(channel=ChannelQuery(channel_id=channel_id, run_id=run_id))
            for channel_id in channel_ids
        ]
        request_kwargs: dict[str, Any] = {
            "queries": queries,
            "sample_ms": 0,
            "start_time": start_time,
            "end_time": end_time,
            "page_size": page_size,
            "page_token": page_token,
        }

        request = GetDataRequest(**request_kwargs)
        response = await self._grpc_client.get_stub(DataServiceStub).GetData(request)
        response = cast(GetDataResponse, response)
        return response.data, response.next_page_token  # type: ignore # mypy doesn't know RepeatedCompositeFieldContainer can be treated like a list

    def _filter_cached_channels(self, channel_ids: List[str]) -> Tuple[List[str], List[str]]:
        """
        Split channel IDs into those that have a cache entry and those that do not.

        Returns:
            A tuple of (cached channel IDs, not cached channel IDs), each in input order.
        """
        cached_channels = []
        not_cached_channels = []
        for channel_id in channel_ids:
            if self.channel_cache.channels.get(channel_id):
                cached_channels.append(channel_id)
            else:
                not_cached_channels.append(channel_id)
        return cached_channels, not_cached_channels

    def _check_cache(
        self,
        *,
        channel_id: str,
        start_time: datetime,
        end_time: datetime,
        run_id: str | None = None,
    ) -> Tuple[pd.DataFrame | None, datetime | None, datetime | None]:
        """
        Check if the data for a channel during a run is cached and return how to query remaining data if so.

        There are a variety of requested start/end time vs cached start/end time cases to consider.
        Below diagram represents time aligned ranges for each case:

        Cache interval:        |-------------------------------|
        Case 1:                  |---------------------------|
        Case 2:                      |--------------------------------|
        Case 3:                                                   |----------|
        Case 4:         |--------------------------------|
        Case 5:  |------|   or   |-----------------------------------------|
                                 (request starts before and ends after the cache)

        Returns:
            A tuple of (data, start_time, end_time)
            where data is a pandas dataframe and start and end times are what should be used for the next call based on what is not covered by the cached data.
            (None, None) for the times means the cache fully covers the request; data of None
            means the cache cannot be used and the full range must be queried.
        """
        # NOTE(review): run_id is accepted but not used -- the cache is keyed by channel ID only.
        # Confirm that data cached for one run cannot be served for a different run_id.
        cached_data = self.channel_cache.channels.get(channel_id)
        ret_start_time = start_time
        ret_end_time = end_time
        ret_data = None
        if cached_data:
            start_time_cached = cached_data.start_time
            end_time_cached = cached_data.end_time
            ret_data = cached_data.data
            # Filter data to desired time range
            ret_data = ret_data[start_time:end_time]  # type: ignore # mypy doesn't understand pandas that well seemingly

            if start_time_cached <= start_time:
                if start_time < end_time_cached:
                    if end_time <= end_time_cached:
                        # Case 1
                        ret_start_time = None  # type: ignore
                        ret_end_time = None  # type: ignore
                    else:
                        # Case 2
                        ret_start_time = end_time_cached
                        ret_end_time = end_time
                else:
                    # Case 3
                    return (None, start_time, end_time)
            else:
                if start_time_cached < end_time and end_time <= end_time_cached:
                    # Case 4
                    ret_start_time = start_time
                    ret_end_time = start_time_cached
                else:
                    # Case 5
                    return (None, start_time, end_time)

        return (ret_data, ret_start_time, ret_end_time)

    def _update_cache(
        self,
        *,
        channel_data: dict[str, pd.DataFrame],
        start_time: datetime,
        end_time: datetime,
        run_id: str | None = None,
    ):
        """
        Update the cache with the new data and start/end times.

        Args:
            channel_data: Data frames keyed by channel name, as returned by get_channel_data.
            start_time: Start of the time range that was requested.
            end_time: End of the time range that was requested.
            run_id: Run the data was queried for, if any.

        Raises:
            ValueError: If a channel name is missing from the name -> id map.
        """
        assert start_time is not None
        assert end_time is not None
        name_id_map = self.channel_cache.name_id_map

        for channel_name, data in channel_data.items():
            channel_id = name_id_map.get(channel_name)
            if not channel_id:
                raise ValueError(
                    f"{channel_name} not found in name_id_map. Not sure got data for this channel without a call that should've updated the map."
                )

            suggested_start_time = start_time
            if run_id:
                if len(data) > 0:
                    suggested_start_time = data.index[0]
                else:
                    # Because we didn't get any data, we can't know what the start time should be.
                    # And because this was queried w/ a run ID, we can't say there's no data before the run started.
                    # So we just don't update the cache.
                    continue

            if channel_id in self.channel_cache.channels:
                # Merge with the existing entry; for duplicate timestamps the last row wins.
                self.channel_cache.channels[channel_id].data = (
                    pd.concat([self.channel_cache.channels[channel_id].data, data])
                    .groupby(level=0)
                    .last()
                )
                self.channel_cache.channels[channel_id].start_time = min(
                    suggested_start_time, self.channel_cache.channels[channel_id].start_time
                )
                self.channel_cache.channels[channel_id].end_time = max(
                    end_time, self.channel_cache.channels[channel_id].end_time
                )
            else:
                self.channel_cache.channels[channel_id] = ChannelCacheEntry(
                    data=data,
                    start_time=suggested_start_time,
                    end_time=end_time,
                )

    async def get_channel_data(
        self,
        *,
        channels: List[Channel],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int | None = None,
        ignore_cache: bool = False,
    ):
        """
        Get the data for one or more channels, using cached data where possible.

        Args:
            channels: The channels to fetch data for.
            run_id: Optional run to scope the query to.
            start_time: Start of the time range. Defaults to the Unix epoch (UTC).
            end_time: End of the time range. Defaults to the current time (UTC).
            limit: Used to derive the page size (capped at 1000) and the max_results
                value handed to the pagination helper.
            ignore_cache: If True, query every channel from the server instead of reading the cache.

        Returns:
            A dict mapping channel name (or "<channel>.<bit field element>") to a data frame.
        """
        ret_data = {}
        # No data will be returned if end_time is not provided.
        start_time = start_time or datetime.fromtimestamp(0, tz=timezone.utc)
        end_time = end_time or datetime.now(timezone.utc)

        self._update_name_id_map(channels)
        channel_ids = [c.id_ for c in channels]
        cached_channels, not_cached_channels = (
            ([], channel_ids) if ignore_cache else self._filter_cached_channels(channel_ids)  # type: ignore
        )

        tasks = []
        page_size = limit if limit and limit < 1000 else 1000
        limit = ceil(limit / page_size) if limit else 10
        # Queue up calls for non-cached channels in batches.
        batch_size = REQUEST_BATCH_SIZE
        for i in range(0, len(not_cached_channels), batch_size):  # type: ignore
            batch = not_cached_channels[i : i + batch_size]  # type: ignore

            task = asyncio.create_task(
                self._handle_pagination(
                    self._get_data_impl,
                    kwargs={
                        "channel_ids": batch,
                        "run_id": run_id,
                        "start_time": start_time,
                        "end_time": end_time,
                    },
                    page_size=page_size,
                    max_results=limit,
                )
            )
            tasks.append(task)

        # Handling cached channels 1 by 1 instead of in batches to account for channels that may have been cached from calls with different start/end times.
        for channel_id in cached_channels:
            cached_data, new_start_time, new_end_time = self._check_cache(
                channel_id=channel_id,
                start_time=start_time,
                end_time=end_time,
                run_id=run_id,
            )

            if cached_data is not None:
                for name in cached_data.columns:
                    ret_data[name] = cached_data
            if new_start_time is None:
                # Cache fully encompassed the desired time range so don't queue a call.
                continue
            task = asyncio.create_task(
                self._handle_pagination(
                    self._get_data_impl,
                    kwargs={
                        "channel_ids": [channel_id],
                        "run_id": run_id,
                        "start_time": new_start_time,
                        "end_time": new_end_time or end_time,
                    },
                    page_size=page_size,
                    max_results=limit,
                )
            )
            tasks.append(task)

        pages = await asyncio.gather(*tasks)
        # Flatten the data
        for page in pages:
            for data in page:
                page_results = self.try_deserialize_channel_data(data)
                for name, df in page_results.items():
                    if name not in ret_data:
                        ret_data[name] = df
                    else:
                        # Merge with cached/earlier data; the last row wins for duplicate timestamps.
                        ret_data[name] = pd.concat([ret_data[name], df]).groupby(level=0).last()

        self._update_cache(
            channel_data=ret_data, start_time=start_time, end_time=end_time, run_id=run_id
        )

        return ret_data

    @staticmethod
    def try_deserialize_channel_data(channel_data: Any) -> dict[str, pd.DataFrame]:
        """
        Deserialize a channel data object into pandas data frames.

        Args:
            channel_data: A message exposing `type_url` and `value` (the packed channel values).

        Returns:
            A dict mapping channel name to a single-column data frame indexed by timestamp.
            Bit field channels yield one entry per element, named "<channel>.<element>".

        Raises:
            ValueError: If the type_url does not map to a known channel data type.
        """
        data_type = ChannelDataType.from_str(channel_data.type_url)
        if data_type is None:
            raise ValueError(f"Unknown data type: {channel_data.type_url}")

        proto_data_class = ChannelDataType.proto_data_class(data_type)
        proto_data_value = proto_data_class.FromString(channel_data.value)
        metadata = proto_data_value.metadata
        ret_data = {}

        # Bit field messages carry one value series per element; everything else is a single series.
        components = (
            proto_data_value.values if proto_data_class == BitFieldValues else [proto_data_value]
        )
        for component in components:
            name = metadata.channel.name
            time_column = []
            value_column = []
            if proto_data_class == BitFieldValues:
                name += "." + component.name
            for value_obj in component.values:
                time_column.append(to_timestamp_nanos(value_obj.timestamp))
                value_column.append(value_obj.value)
            df = pd.DataFrame({name: value_column}, index=time_column)
            ret_data[name] = df

        return ret_data
logger = logging.getLogger(__name__)


class IngestionThread(threading.Thread):
    """
    Manages ingestion for a single ingestion config.

    The thread runs its own asyncio event loop, builds a sift stream, and forwards
    everything placed on ``data_queue`` to that stream until it is stopped or has
    been idle for ``no_data_timeout`` seconds.
    """

    IDLE_LOOP_PERIOD = 0.1  # Time of intervals loop will sleep while waiting for data.
    SIFT_STREAM_FINISH_TIMEOUT = 0.06  # Measured ~0.05s to finish stream.
    CLEANUP_TIMEOUT = IDLE_LOOP_PERIOD + SIFT_STREAM_FINISH_TIMEOUT

    def __init__(
        self,
        sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy,
        data_queue: Queue,
        ingestion_config: IngestionConfigFormPy,
        no_data_timeout: int = 1,
        metric_interval: float = 0.5,
    ):
        """
        Initialize the IngestionThread.

        Args:
            sift_stream_builder: The sift stream builder to build a new stream.
            data_queue: The queue to put IngestWithConfigDataStreamRequestPy requests into for ingestion.
            ingestion_config: The ingestion config to use for ingestion.
            no_data_timeout: The number of (whole number) seconds to wait for data before stopping the thread (Saves minorly on startup resources. Ingesting new data will always restart the thread if it is stopped).
            metric_interval: Time (seconds) to wait between logging metrics.
        """
        super().__init__(daemon=True)
        self.data_queue = data_queue
        self._stop_event = threading.Event()
        self.sift_stream_builder = sift_stream_builder
        self.ingestion_config = ingestion_config
        self.no_data_timeout = no_data_timeout
        self.metric_interval = metric_interval
        self.initialized = False
        # Set by _run() once the thread's event loop is running; None until then so
        # stop() can be called safely on a thread that never got that far.
        self.task = None
        self._loop = None
        # Exception raised while building the stream, if any (see await_stream_build).
        self._build_error = None

    def stop(self):
        """
        Ask the thread to stop, give it a brief chance to finish, then cancel its task.

        Intended to be called from a thread other than this one.
        """
        self._stop_event.set()
        # Give a brief chance to finish the stream (should take < 50ms).
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=self.CLEANUP_TIMEOUT)

        task, loop = self.task, self._loop
        if task is None or loop is None or task.done():
            return
        try:
            # Task.cancel() is not thread-safe; schedule it on the loop that owns the task.
            loop.call_soon_threadsafe(task.cancel)
        except RuntimeError:
            # The loop was closed in the meantime, i.e. the thread already finished.
            pass

    async def await_stream_build(self):
        """
        Wait until the thread has built its stream.

        Raises:
            RuntimeError: If building the stream failed (chained to the original error).
        """
        while not self.initialized:
            if self._build_error is not None:
                raise RuntimeError(
                    "Ingestion thread failed to build the stream"
                ) from self._build_error
            await asyncio.sleep(0.01)

    async def main(self):
        """Build the stream and forward queued requests until stopped or idle for too long."""
        logger.debug("Ingestion thread started")
        self.sift_stream_builder.ingestion_config = self.ingestion_config
        try:
            sift_stream = await self.sift_stream_builder.build()
        except BaseException as err:
            # Record the failure so await_stream_build() can report it instead of waiting forever.
            self._build_error = err
            raise
        time_of_last_metric = time.time()
        time_of_last_data = time.time()
        count = 0
        self.initialized = True
        try:
            while True:
                while not self.data_queue.empty():
                    if self._stop_event.is_set():
                        # Being forced to stop. Try to finish the stream.
                        logger.info(
                            f"Ingestion thread received stop signal. Exiting. Sent {count} requests. {self.data_queue.qsize()} requests remaining."
                        )
                        await sift_stream.finish()
                        return
                    item = self.data_queue.get()
                    sift_stream = await sift_stream.send_requests(item)
                    count += 1
                    # Track when data last flowed so the idle timeout below is measured
                    # from the most recent request rather than from thread start.
                    time_of_last_data = time.time()
                    time_since_last_metric = time_of_last_data - time_of_last_metric
                    if time_since_last_metric > self.metric_interval:
                        logger.debug(
                            f"Ingestion thread sent {count} requests, remaining: {self.data_queue.qsize()}"
                        )
                        time_of_last_metric = time.time()

                # Queue empty, check if we should stop.
                time_since_last_data = time.time() - time_of_last_data
                if self._stop_event.is_set() or time_since_last_data > self.no_data_timeout:
                    logger.debug(
                        f"No more requests. Stopping. Sent {count} requests. {self.data_queue.qsize()} requests remaining."
                    )
                    await sift_stream.finish()
                    return
                else:
                    await asyncio.sleep(self.IDLE_LOOP_PERIOD)

        except asyncio.CancelledError:
            # It's possible the thread was joined while sleeping waiting for data. Only note error if we have data left.
            if self.data_queue.qsize() > 0:
                logger.error(
                    f"Ingestion thread cancelled without finishing stream. {self.data_queue.qsize()} requests were not sent."
                )

    async def _run(self):
        """Run main() as a task so that stop() can cancel it."""
        self._loop = asyncio.get_running_loop()
        self.task = asyncio.create_task(self.main())
        await self.task

    def run(self):
        """This thread will handle sending data to Sift."""
        # Even though this is a thread, we need to run this async task to await send_requests otherwise we get sift_stream consumed errors.
        asyncio.run(self._run())
class IngestionLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """
    Low-level client for the IngestionAPI.

    This class provides a thin wrapper around the autogenerated bindings for the IngestionAPI.
    It handles common concerns like error handling and retries.
    """

    CacheEntry = namedtuple("CacheEntry", ["data_queue", "ingestion_config", "thread"])

    sift_stream_builder: sift_stream_bindings.SiftStreamBuilderPy
    # Maps ingestion config ID to the queue/config/thread used to ingest for it.
    # NOTE(review): class attribute, so the cache is shared by every instance in the
    # process -- confirm that is intended.
    stream_cache: Dict[str, "CacheEntry"] = {}

    def __init__(self, grpc_client: GrpcClient):
        """
        Initialize the IngestionLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client=grpc_client)
        # Rust GRPC client expects URI to have http(s):// prefix.
        uri = grpc_client._config.uri
        if not uri.startswith("http"):
            uri = f"https://{uri}" if grpc_client._config.use_ssl else f"http://{uri}"
        self.sift_stream_builder = sift_stream_bindings.SiftStreamBuilderPy(
            uri=uri,
            apikey=grpc_client._config.api_key,
        )
        self.sift_stream_builder.enable_tls = grpc_client._config.use_ssl
        # FD-177: Expose configuration for recovery strategy.
        self.sift_stream_builder.recovery_strategy = (
            sift_stream_bindings.RecoveryStrategyPy.retry_only(
                sift_stream_bindings.RetryPolicyPy.default()
            )
        )

        # Give in-flight ingestion a short chance to drain when the interpreter exits.
        atexit.register(self.cleanup, timeout=0.1)

    def cleanup(self, timeout: float | None = None):
        """
        Cleanup the ingestion threads.

        Args:
            timeout: The timeout in seconds to wait for ingestion to complete. If None, will wait forever.
        """
        # Iterate over a snapshot so entries added or replaced concurrently cannot
        # invalidate the iteration.
        for cache_entry in list(self.stream_cache.values()):
            _data_queue, _ingestion_config, thread = cache_entry
            if thread:
                # The thread exits on its own once its queue is drained and idle; wait for that first.
                thread.join(timeout=timeout)
                if thread.is_alive():
                    logger.error(
                        f"Ingestion thread did not finish after {timeout} seconds. Forcing stop."
                    )
                    thread.stop()

    async def get_ingestion_config_flows(self, ingestion_config_id: str) -> List[Flow]:
        """
        Get the flows for an ingestion config.

        Args:
            ingestion_config_id: The id of the ingestion config whose flows to list.

        Returns:
            All flows of the ingestion config, across every result page.
        """
        # Imported locally because only the response type is imported at module level.
        from sift.ingestion_configs.v2.ingestion_configs_pb2 import (
            ListIngestionConfigFlowsRequest,
        )

        stub = self._grpc_client.get_stub(IngestionConfigServiceStub)
        flows: List[Flow] = []
        page_token = ""
        while True:
            # The flows live behind the ListIngestionConfigFlows RPC; GetIngestionConfig
            # returns only the ingestion config itself and carries no `flows` field.
            res = await stub.ListIngestionConfigFlows(
                ListIngestionConfigFlowsRequest(
                    ingestion_config_id=ingestion_config_id,
                    page_token=page_token,
                )
            )
            res = cast(ListIngestionConfigFlowsResponse, res)
            flows.extend(Flow._from_proto(flow) for flow in res.flows)
            page_token = res.next_page_token
            if not page_token:
                return flows

    async def list_ingestion_configs(self, filter_query: str) -> List[IngestionConfig]:
        """
        List ingestion configs.

        Args:
            filter_query: A CEL filter string selecting the ingestion configs.

        Returns:
            The ingestion configs contained in the response.
        """
        res = await self._grpc_client.get_stub(IngestionConfigServiceStub).ListIngestionConfigs(
            ListIngestionConfigsRequest(filter=filter_query)
        )
        res = cast(ListIngestionConfigsResponse, res)
        return [IngestionConfig._from_proto(config) for config in res.ingestion_configs]

    async def get_ingestion_config_id_from_client_key(self, client_key: str) -> str | None:
        """
        Get the ingestion config id.

        Args:
            client_key: The client key identifying the ingestion config.

        Returns:
            The id of the matching ingestion config, or None if there is none.

        Raises:
            ValueError: If more than one ingestion config matches the client key.
        """
        filter_query = cel.equals("client_key", client_key)
        ingestion_configs = await self.list_ingestion_configs(filter_query)
        if not ingestion_configs:
            return None
        if len(ingestion_configs) > 1:
            raise ValueError(
                f"Expected 1 ingestion config for client key {client_key}, got {len(ingestion_configs)}"
            )
        return ingestion_configs[0].id_

    def _new_ingestion_thread(
        self, ingestion_config_id: str, ingestion_config: IngestionConfigFormPy | None = None
    ):
        """Start a new ingestion thread.
        This allows ingestion to happen in the background regardless of if the user is using the sync or async client
        and without them having to set up threading themselves. We are using a thread vs asyncio since our
        sync wrapper will block on incomplete tasks.

        Args:
            ingestion_config_id: The id of the ingestion config for the flows this stream will ingest. Used to cache the stream.
            ingestion_config: The ingestion config to use for ingestion.

        Returns:
            The CacheEntry (data_queue, ingestion_config, thread) to use for this ingestion config.
            If a live thread is already cached, its existing entry is returned unchanged.
        """
        existing = self.stream_cache.get(ingestion_config_id)
        if existing:
            existing_data_queue, existing_ingestion_config, existing_thread = existing
            if existing_thread.is_alive():
                # Return the whole entry (not the bare thread) so callers can unpack it
                # or store it back into the cache.
                return existing
            ingestion_config = existing_ingestion_config
            # Re-use existing queue since ingest_flow has already put data on it.
            data_queue = existing_data_queue
        else:
            data_queue = Queue()
        assert ingestion_config is not None  # Appease mypy.
        thread = IngestionThread(self.sift_stream_builder, data_queue, ingestion_config)
        thread.start()

        return self.CacheEntry(data_queue, ingestion_config, thread)

    def _hash_flows(self, asset_name: str, flows: List[Flow]) -> str:
        """
        Generate a client key that should be unique but deterministic for the given asset and flow configuration.

        Args:
            asset_name: The name of the asset the flows belong to.
            flows: The flows to hash.

        Returns:
            The hex SHA-256 digest over the asset name and flow/channel definitions.
        """
        # TODO: Taken from sift_py/ingestion/config/telemetry.py. Confirm intent from Marc.
        m = hashlib.sha256()
        m.update(asset_name.encode())
        for flow in sorted(flows, key=lambda f: f.name):
            m.update(flow.name.encode())
            # Do not sort channels in alphabetical order since order matters.
            for channel in flow.channels:
                m.update(channel.name.encode())
                # Use api_format for data type since that should be consistent between languages.
                m.update(channel.data_type.hash_str(api_format=True).encode())
                m.update((channel.description or "").encode())
                m.update((channel.unit or "").encode())
                if channel.bit_field_elements:
                    for bfe in sorted(channel.bit_field_elements, key=lambda bfe: bfe.index):
                        m.update(bfe.name.encode())
                        m.update(str(bfe.index).encode())
                        m.update(str(bfe.bit_count).encode())
                if channel.enum_types:
                    for enum_name, enum_key in sorted(
                        channel.enum_types.items(), key=lambda it: it[1]
                    ):
                        m.update(str(enum_key).encode())
                        m.update(enum_name.encode())

        return m.hexdigest()

    async def create_ingestion_config(
        self,
        *,
        asset_name: str,
        flows: List[Flow],
        client_key: str | None = None,
        organization_id: str | None = None,
    ) -> str:
        """
        Create an ingestion config.

        Args:
            asset_name: The name of the asset to ingest to.
            flows: The flows to ingest.
            client_key: The client key to use for ingestion. If not provided, a new one will be generated.
            organization_id: The organization id to use for ingestion. Only needed if the user is part of several organizations.

        Returns:
            The id of the new or found ingestion config.

        Raises:
            ValueError: If a flow already exists on the ingestion config found for an explicit
                client key, or if no ingestion config id could be determined.
        """
        ingestion_config_id = None
        if client_key:
            logger.debug(f"Getting ingestion config id for client key {client_key}")
            ingestion_config_id = await self.get_ingestion_config_id_from_client_key(client_key)
            if ingestion_config_id:
                # Perform validation that the flows are valid for the ingestion config.
                existing_flows = await self.get_ingestion_config_flows(ingestion_config_id)
                existing_flow_names = {existing_flow.name for existing_flow in existing_flows}
                for flow in flows:
                    if flow.name in existing_flow_names:
                        raise ValueError(
                            f"Flow {flow.name} already exists for ingestion client {client_key}"
                        )
        else:
            client_key = self._hash_flows(asset_name, flows)
            try:
                logger.debug(f"Getting ingestion config id from generated client key {client_key}")
                ingestion_config_id = await self.get_ingestion_config_id_from_client_key(client_key)
            except ValueError as err:
                # Best effort: fall through and let the stream builder create/resolve the config.
                logger.debug(
                    f"Could not resolve an ingestion config for generated client key {client_key} ({err}). Creating new one."
                )

        data_queue, ingestion_config, thread = (
            self.stream_cache.get(ingestion_config_id, (None, None, None))
            if ingestion_config_id
            else (None, None, None)
        )
        if not (thread and thread.is_alive()):
            ingestion_config = IngestionConfigFormPy(
                asset_name=asset_name,
                flows=[flow._to_rust_config() for flow in flows],
                client_key=client_key,
            )

            cache_entry = self._new_ingestion_thread(ingestion_config_id or "", ingestion_config)
            if not ingestion_config_id:
                # No ingestion config ID exists for client key but stream builder in ingestion thread should create it.
                await cache_entry.thread.await_stream_build()
                ingestion_config_id = await self.get_ingestion_config_id_from_client_key(client_key)
                assert ingestion_config_id is not None, (
                    "No ingestion config id found after building new stream. Likely server error."
                )
            logger.debug(f"Built new stream for ingestion config {ingestion_config_id}")
            self.stream_cache[ingestion_config_id] = cache_entry

        # Tag the flows so ingest_flow can find the stream for them.
        for flow in flows:
            flow.ingestion_config_id = ingestion_config_id

        if not ingestion_config_id:
            raise ValueError("No ingestion config id found")
        return ingestion_config_id

    def wait_for_ingestion_to_complete(self, timeout: float | None = None):
        """
        Blocks until all ingestion is complete.

        Args:
            timeout: The timeout in seconds to wait for ingestion to complete. If None, will wait forever.
        """
        logger.debug("Waiting for ingestion to complete")
        self.cleanup(timeout)

    def ingest_flow(
        self,
        *,
        flow: Flow,
        timestamp: datetime,
        channel_values: dict[str, Any],
        organization_id: str | None = None,
    ):
        """
        Ingest a flow. This is a synchronous call that queues an ingestion request that will be processed asynchronously on a background thread.

        Args:
            flow: The flow to ingest.
            timestamp: The timestamp of the flow.
            channel_values: The channel values to ingest.
            organization_id: The organization id to use for ingestion. Only relevant if the user is part of several organizations.

        Raises:
            ValueError: If the flow has no ingestion config id, or no stream is cached for it.
        """

        if not flow.ingestion_config_id:
            raise ValueError(
                "Flow has no ingestion config id -- have you created an ingestion config for this flow?"
            )
        cache_entry = self.stream_cache.get(flow.ingestion_config_id)
        if not cache_entry:
            raise ValueError(
                f"Ingestion config {flow.ingestion_config_id} not found. Have you created an ingestion config for this flow?"
            )
        # Iterate through all expected channels for flow and convert to ingestion types (missing channels use a special empty type)
        rust_channel_values = [
            _to_rust_value(channel, channel_values.get(channel.name)) for channel in flow.channels
        ]
        req = IngestWithConfigDataStreamRequestPy(
            ingestion_config_id=flow.ingestion_config_id,
            run_id=flow.run_id or "",
            flow=flow.name,
            timestamp=to_rust_py_timestamp(timestamp),
            channel_values=rust_channel_values,
            end_stream_on_validation_error=False,
            organization_id=organization_id or "",  # This will be filled in by the server
        )
        data_queue, ingestion_config, thread = cache_entry
        assert data_queue is not None
        # Put data on queue before potentially starting a new thread so it doesn't initially sleep waiting for data.
        data_queue.put([req])
        if not (thread and thread.is_alive()):
            # We previously had a thread for this ingestion config but it finished ingestion so create a new one.
            self.stream_cache[flow.ingestion_config_id] = self._new_ingestion_thread(
                flow.ingestion_config_id, ingestion_config
            )
class RulesLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """
    Low-level client for the RulesAPI.

    This class provides a thin wrapper around the autogenerated bindings for the RulesAPI.
    """

    def __init__(self, grpc_client: GrpcClient):
        """
        Initialize the RulesLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client)

    async def get_rule(self, rule_id: str | None = None, client_key: str | None = None) -> Rule:
        """
        Get a rule by rule_id or client_key.

        Args:
            rule_id: The rule ID to get.
            client_key: The client key to get.

        Returns:
            The Rule.

        Raises:
            ValueError: If neither rule_id nor client_key is provided.
        """
        # Fail fast like the sibling archive/restore methods instead of sending an empty request.
        if rule_id is None and client_key is None:
            raise ValueError("Either rule_id or client_key must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_id is not None:
            request_kwargs["rule_id"] = rule_id
        if client_key is not None:
            request_kwargs["client_key"] = client_key

        request = GetRuleRequest(**request_kwargs)
        response = await self._grpc_client.get_stub(RuleServiceStub).GetRule(request)
        grpc_rule = cast(GetRuleResponse, response).rule
        return Rule._from_proto(grpc_rule)

    async def batch_get_rules(
        self, rule_ids: list[str] | None = None, client_keys: list[str] | None = None
    ) -> list[Rule]:
        """
        Get multiple rules by rule_ids or client_keys.

        Args:
            rule_ids: List of rule IDs to get.
            client_keys: List of client keys to get.

        Returns:
            List of Rules.

        Raises:
            ValueError: If neither rule_ids nor client_keys is provided.
        """
        if rule_ids is None and client_keys is None:
            raise ValueError("Either rule_ids or client_keys must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_ids is not None:
            request_kwargs["rule_ids"] = rule_ids
        if client_keys is not None:
            request_kwargs["client_keys"] = client_keys

        request = BatchGetRulesRequest(**request_kwargs)
        response = await self._grpc_client.get_stub(RuleServiceStub).BatchGetRules(request)
        response = cast(BatchGetRulesResponse, response)
        return [Rule._from_proto(rule) for rule in response.rules]

    async def create_rule(
        self,
        *,
        name: str,
        description: str,
        organization_id: str | None = None,
        client_key: str | None = None,
        asset_ids: list[str] | None = None,
        tag_ids: list[str] | None = None,
        contextual_channels: list[str] | None = None,
        is_external: bool,
        expression: str,
        channel_references: List[ChannelReference],
        action: RuleAction,
    ) -> Rule:
        """
        Create a new rule.

        The rule is created enabled, with a single condition built from ``expression`` and
        ``channel_references`` and carrying ``action``.

        Args:
            name: The name of the rule.
            description: The description of the rule.
            organization_id: The organization ID of the rule. Sent as an empty string when omitted.
            client_key: The client key of the rule.
            asset_ids: The asset IDs of the rule.
            tag_ids: The tag IDs placed in the rule's asset configuration.
            contextual_channels: Optional contextual channels of the rule.
            is_external: Forwarded to the request's ``is_external`` field.
            expression: The expression of the rule's condition.
            channel_references: The channel references used by the expression.
            action: The action attached to the rule's condition.

        Returns:
            The created Rule, re-fetched from the API after creation.
        """
        # The create request wraps an UpdateRuleRequest, so build that first.
        expression_proto = RuleConditionExpression(
            calculated_channel=CalculatedChannelConfig(
                expression=expression,
                channel_references={
                    c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
                    for c in channel_references
                },
            )
        )
        conditions_request = [
            UpdateConditionRequest(
                expression=expression_proto, actions=[action._to_update_request()]
            )
        ]
        update_request = UpdateRuleRequest(
            name=name,
            description=description,
            is_enabled=True,
            organization_id=organization_id or "",
            client_key=client_key,
            is_external=is_external,
            conditions=conditions_request,
            asset_configuration=RuleAssetConfiguration(
                asset_ids=asset_ids or [],
                tag_ids=tag_ids or [],
            ),
            contextual_channels=ContextualChannels(
                channels=[ChannelReferenceProto(name=c) for c in contextual_channels or []]
            ),  # type: ignore
        )

        request = CreateRuleRequest(update=update_request)
        created_rule = cast(
            CreateRuleResponse,
            await self._grpc_client.get_stub(RuleServiceStub).CreateRule(request),
        )
        return await self.get_rule(rule_id=created_rule.rule_id, client_key=client_key)

    def _update_rule_request_from_update(
        self, rule: Rule, update: RuleUpdate, version_notes: str | None = None
    ) -> UpdateRuleRequest:
        """
        Create an update request from a rule and update.

        This helper exists because the Rule update protos need a pattern that is less generic than the normal update
        mask pattern of other types.

        Args:
            rule: The current rule; supplies values for fields the update leaves unset.
            update: The update to apply. Only fields explicitly set on it are treated as changes.
            version_notes: Notes to include in the rule version.

        Returns:
            The UpdateRuleRequest to send.

        Raises:
            ValueError: If exactly one of expression / channel_references ends up empty.
        """
        model_dump = update.model_dump(exclude_unset=True, exclude_none=False)

        update_dict: dict[str, Any] = {
            "version_notes": version_notes,
        }
        nontrivial_updates = {
            "expression",
            "channel_references",
            "action",
            "contextual_channels",
            "asset_ids",
            "asset_tag_ids",
        }
        # Need to manually copy fields that will be reset even if not provided in update dict.
        copy_unset_fields = [
            "description",
        ]

        # Populate the trivial fields first.
        for updated_field, value in model_dump.items():
            if updated_field not in nontrivial_updates:
                update_dict[updated_field] = value
        # Populate the fields that weren't updated but will be reset if not provided in request.
        for field in copy_unset_fields:
            if field not in model_dump:
                update_dict[field] = getattr(rule, field)

        # Special handling for the more complex fields.
        # Also, these must always be set.
        expression = model_dump.get("expression", rule.expression)
        channel_references: List[ChannelReference] = (
            update.channel_references
            if "channel_references" in model_dump
            else rule.channel_references
        ) or []
        action = update.action if "action" in model_dump else rule.action
        if bool(expression) != bool(channel_references):
            raise ValueError(
                "Expression and channel_references must both be provided or both be None"
            )
        expression_proto = RuleConditionExpression(
            calculated_channel=CalculatedChannelConfig(
                expression=expression,
                channel_references={
                    c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
                    for c in channel_references
                },
            )
            if expression
            else None
        )
        conditions_request = [
            UpdateConditionRequest(
                expression=expression_proto,
                actions=[action._to_update_request()] if action else None,
            )
        ]
        update_dict["conditions"] = conditions_request
        if "contextual_channels" in model_dump:
            update_dict["contextual_channels"] = ContextualChannels(
                channels=[ChannelReferenceProto(name=c) for c in update.contextual_channels or []]
            )

        # This always needs to be set, so handle the defaults.
        # The parentheses matter: `a if c else b or []` binds as `a if c else (b or [])`,
        # which would skip the empty-list fallback when the update explicitly sets None.
        asset_ids = (update.asset_ids if "asset_ids" in model_dump else rule.asset_ids) or []
        tag_ids = (
            update.asset_tag_ids if "asset_tag_ids" in model_dump else rule.asset_tag_ids
        ) or []
        update_dict["asset_configuration"] = RuleAssetConfiguration(
            asset_ids=asset_ids,
            tag_ids=tag_ids,
        )

        return UpdateRuleRequest(
            rule_id=rule.id_,
            **update_dict,
        )

    async def update_rule(
        self, rule: Rule, update: RuleUpdate, version_notes: str | None = None
    ) -> Rule:
        """
        Update a rule.

        Args:
            rule: The rule to update.
            update: The update to apply. Its ``resource_id`` is set to the rule's id.
            version_notes: Notes to include in the rule version.

        Returns:
            The updated Rule, re-fetched from the API after the update.
        """
        update.resource_id = rule.id_

        update_request = self._update_rule_request_from_update(rule, update, version_notes)

        response = await self._grpc_client.get_stub(RuleServiceStub).UpdateRule(update_request)
        updated_grpc_rule = cast(UpdateRuleResponse, response)
        # Get the updated rule
        return await self.get_rule(rule_id=updated_grpc_rule.rule_id)

    async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse:
        """
        Batch update rules.

        Each rule is fetched first, because building its update request needs the rule's
        current values for fields the update leaves unset.

        Args:
            rules: List of rule updates to apply. Each must carry the ``resource_id`` of its rule.

        Returns:
            The batch update response.
        """
        update_requests = []
        for rule_update in rules:
            rule = await self.get_rule(rule_id=rule_update.resource_id)
            update_requests.append(self._update_rule_request_from_update(rule, rule_update))

        batch_request = BatchUpdateRulesRequest(rules=update_requests)  # type: ignore
        response = await self._grpc_client.get_stub(RuleServiceStub).BatchUpdateRules(
            batch_request
        )
        return cast(BatchUpdateRulesResponse, response)

    async def archive_rule(self, rule_id: str | None = None, client_key: str | None = None) -> None:
        """
        Archive a rule.

        Args:
            rule_id: The rule ID to archive.
            client_key: The client key to archive.

        Raises:
            ValueError: If neither rule_id nor client_key is provided.
        """
        if rule_id is None and client_key is None:
            raise ValueError("Either rule_id or client_key must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_id is not None:
            request_kwargs["rule_id"] = rule_id
        if client_key is not None:
            request_kwargs["client_key"] = client_key

        # Archiving is expressed through the delete RPC: the request type is DeleteRuleRequest,
        # matching how batch_archive_rules pairs BatchDeleteRulesRequest with BatchDeleteRules.
        request = DeleteRuleRequest(**request_kwargs)
        await self._grpc_client.get_stub(RuleServiceStub).DeleteRule(request)

    async def batch_archive_rules(
        self, rule_ids: List[str] | None = None, client_keys: List[str] | None = None
    ) -> None:
        """
        Batch archive rules.

        Args:
            rule_ids: List of rule IDs to archive.
            client_keys: List of client keys to archive. If both are provided, rule_ids will be used.

        Raises:
            ValueError: If neither rule_ids nor client_keys is provided.
        """
        if rule_ids is None and client_keys is None:
            raise ValueError("Either rule_ids or client_keys must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_ids is not None:
            request_kwargs["rule_ids"] = rule_ids
        if client_keys is not None:
            request_kwargs["client_keys"] = client_keys

        request = BatchDeleteRulesRequest(**request_kwargs)
        await self._grpc_client.get_stub(RuleServiceStub).BatchDeleteRules(request)

    async def restore_rule(self, rule_id: str | None = None, client_key: str | None = None) -> Rule:
        """
        Restore a rule.

        Args:
            rule_id: The rule ID to restore.
            client_key: The client key to restore.

        Returns:
            The restored Rule, re-fetched from the API after the restore.

        Raises:
            ValueError: If neither rule_id nor client_key is provided.
        """
        if rule_id is None and client_key is None:
            raise ValueError("Either rule_id or client_key must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_id is not None:
            request_kwargs["rule_id"] = rule_id
        if client_key is not None:
            request_kwargs["client_key"] = client_key

        request = UndeleteRuleRequest(**request_kwargs)
        await self._grpc_client.get_stub(RuleServiceStub).UndeleteRule(request)
        # Get the restored rule
        return await self.get_rule(rule_id=rule_id, client_key=client_key)

    async def batch_restore_rules(
        self, rule_ids: List[str] | None = None, client_keys: List[str] | None = None
    ) -> None:
        """
        Batch restore rules.

        Args:
            rule_ids: List of rule IDs to restore.
            client_keys: List of client keys to restore.

        Raises:
            ValueError: If neither rule_ids nor client_keys is provided.
        """
        if rule_ids is None and client_keys is None:
            raise ValueError("Either rule_ids or client_keys must be provided")

        request_kwargs: dict[str, Any] = {}
        if rule_ids is not None:
            request_kwargs["rule_ids"] = rule_ids
        if client_keys is not None:
            request_kwargs["client_keys"] = client_keys

        request = BatchUndeleteRulesRequest(**request_kwargs)
        await self._grpc_client.get_stub(RuleServiceStub).BatchUndeleteRules(request)

    async def list_rules(
        self,
        *,
        filter_query: str | None = None,
        order_by: str | None = None,
        page_size: int | None = None,
        page_token: str | None = None,
    ) -> tuple[List[Rule], str | None]:
        """
        List one page of rules.

        Only arguments that are not None are placed on the request.

        Args:
            filter_query: Sent as the request's ``filter`` field.
            order_by: Sent as the request's ``order_by`` field.
            page_size: Sent as the request's ``page_size`` field.
            page_token: Sent as the request's ``page_token`` field.

        Returns:
            A tuple of the rules on this page and the response's next page token.
        """
        request_kwargs: dict[str, Any] = {}
        if filter_query is not None:
            request_kwargs["filter"] = filter_query
        if order_by is not None:
            request_kwargs["order_by"] = order_by
        if page_size is not None:
            request_kwargs["page_size"] = page_size
        if page_token is not None:
            request_kwargs["page_token"] = page_token

        request = ListRulesRequest(**request_kwargs)
        response = await self._grpc_client.get_stub(RuleServiceStub).ListRules(request)
        return [Rule._from_proto(rule) for rule in response.rules], response.next_page_token

    async def list_all_rules(
        self,
        *,
        filter_query: str | None = None,
        order_by: str | None = None,
        max_results: int | None = None,
        page_size: int | None = None,
    ) -> List[Rule]:
        """
        List all rules by delegating to the base class pagination helper over list_rules.

        Args:
            filter_query: Forwarded to list_rules as ``filter_query``.
            order_by: Passed to the pagination helper as ``order_by``.
            max_results: Passed to the pagination helper as ``max_results``.
            page_size: Passed to the pagination helper as ``page_size``.

        Returns:
            The rules collected by the pagination helper.
        """
        return await self._handle_pagination(
            self.list_rules,
            kwargs={"filter_query": filter_query},
            page_size=page_size,
            order_by=order_by,
            max_results=max_results,
        )
def setup_logger():
    """Set the root logger's level to DEBUG for test runs."""
    logging.getLogger().setLevel(logging.DEBUG)
sift_client.client import SiftClient @@ -15,23 +16,23 @@ async def main(): client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id + asset_id = asset.id_ print(f"Using asset: {asset.name} (ID: {asset_id})") # List runs for this asset - runs = asset.runs() + runs = asset.runs print( f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" ) # Pick one. run = runs[0] - run_id = run.id + run_id = run.id_ print(f"Using run: {run.name} (ID: {run_id})") # List other assets for this run. - all_assets = run.assets() - other_assets = [asset for asset in all_assets if asset.id != asset_id] + all_assets = run.assets + other_assets = [asset for asset in all_assets if asset.id_ != asset_id] print( f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" ) @@ -39,7 +40,7 @@ async def main(): # List channels for this asset (find a run w/ data) channels = [] for run in runs: - asset_channels = asset.channels(run_id=run.id, limit=10) + asset_channels = asset.channels(run_id=run.id_, limit=10) other_channels = [] for c in asset_channels: if c.name in {"voltage", "gpio", "temperature", "mainmotor.velocity"}: @@ -49,7 +50,7 @@ async def main(): if len(channels) > 3: print( - f"Found {len(channels)} channel(s): {[channel.identifier for channel in channels]} for asset {asset.name} on run {run.name}" + f"Found {len(channels)} channel(s): {[channel.name for channel in channels]} for asset {asset.name} on run {run.name}" ) if len(other_channels) > 0: print( @@ -157,6 +158,19 @@ async def main(): ) no_time_time_repeat = time.perf_counter() - perf_start + # Test 7: Get data as arrow + print("\nTest 7: Get data as arrow") + perf_start = time.perf_counter() + channel_data_arrow = client.channels.get_data_as_arrow( + channels=channels, + end_time=fake_no_end_time, + ) + arrow_time = time.perf_counter() - perf_start + for i, 
async def main():
    """
    Exercise the ingestion API end to end against a live Sift deployment.

    Connection settings are read from SIFT_GRPC_URI, SIFT_API_KEY and SIFT_REST_URI.
    The script archives runs left over from earlier invocations, creates a run and two
    flows, ingests simulated data through both ingestion entry points, probes a few
    invalid-input cases, and finally archives the run it created.
    """
    grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051")
    api_key = os.getenv("SIFT_API_KEY", "")
    rest_url = os.getenv("SIFT_REST_URI", "localhost:8080")
    client = SiftClient(
        connection_config=SiftConnectionConfig(
            grpc_url=grpc_url,
            api_key=api_key,
            rest_url=rest_url,
            use_ssl=True,
            cert_via_openssl=True,
        )
    )

    asset = "ian-test-asset"

    # TODO:Get user id from current user
    previously_created_runs = client.runs.list(
        name_regex="test-run-.*", created_by_user_id="1eba461b-fa36-4e98-8fe8-ff32d3e43a6e"
    )
    if previously_created_runs:
        print(f" Deleting previously created runs: {previously_created_runs}")
        for run in previously_created_runs:
            print(f" Deleting run: {run.name}")
            client.runs.archive(run=run)

    run = client.runs.create(
        name=f"test-run-{datetime.now().timestamp()}",
        description="A test run created via the API",
        tags=["api-created", "test"],
    )

    regular_flow = Flow(
        name="test-flow",
        channels=[
            Channel(name="test-channel", data_type=ChannelDataType.DOUBLE),
            Channel(
                name="test-enum-channel",
                data_type=ChannelDataType.ENUM,
                enum_types={"enum1": 1, "enum2": 2},
            ),
        ],
    )
    regular_flow.add_channel(
        Channel(
            name="test-bit-field-channel",
            data_type=ChannelDataType.BIT_FIELD,
            bit_field_elements=[
                ChannelBitFieldElement(name="12v", index=0, bit_count=4),
                ChannelBitFieldElement(name="charge", index=4, bit_count=2),
                ChannelBitFieldElement(name="led", index=6, bit_count=1),
                ChannelBitFieldElement(name="heater", index=7, bit_count=1),
            ],
        )
    )

    highspeed_flow = Flow(
        name="highspeed-flow",
        channels=[
            Channel(name="highspeed-channel", data_type=ChannelDataType.DOUBLE),
        ],
    )
    # This seals the flow and ingestion config
    config_id = await client.async_.ingestion.create_ingestion_config(
        asset_name=asset,
        run_id=run.id_,
        flows=[regular_flow, highspeed_flow],
    )
    print(f"config_id: {config_id}")
    try:
        regular_flow.add_channel(Channel(name="test-channel", data_type=ChannelDataType.DOUBLE))
    except ValueError as e:
        assert repr(e) == "ValueError('Cannot add a channel to a flow after creation')"

    other_asset_flows = [
        Flow(
            name="new-asset-flow",
            channels=[
                # Same channel name as the regular flow, but on a different asset.
                Channel(name="test-channel", data_type=ChannelDataType.DOUBLE),
            ],
        )
    ]
    await client.async_.ingestion.create_ingestion_config(
        asset_name="test-asset-ian2",
        run_id=run.id_,
        flows=other_asset_flows,
    )
    sleep_time = 0.05  # Time between outer loop iterations to simulate real-time latency between ingestion calls.
    simulated_duration = 50
    fake_hs_rate = 50  # Hz
    fake_hs_period = 1 / fake_hs_rate
    start = datetime.now()
    for i in range(simulated_duration):
        now = start + timedelta(seconds=i)
        regular_flow.ingest(
            timestamp=now,
            channel_values={
                "test-channel": 3.0 * math.sin(2 * math.pi * fake_hs_rate * i + 0.07),
                "test-enum-channel": i % 2 + 1,
                "test-bit-field-channel": {
                    "12v": random.randint(3, 13),
                    "charge": random.randint(1, 3),
                    "led": random.choice([0, 1]),
                    "heater": random.choice([0, 1]),
                },
            },
        )
        for j in range(fake_hs_rate):
            val = 3.0 * math.sin(2 * math.pi * fake_hs_rate * (i + j * 0.001) + 0)
            timestamp = now + timedelta(milliseconds=j * fake_hs_period * 1000)
            channel_values = {
                "highspeed-channel": val,
            }
            # Alternative way to ingest
            client.ingestion.ingest(
                flow=highspeed_flow, timestamp=timestamp, channel_values=channel_values
            )
        time.sleep(sleep_time)

    other_asset_flows[0].ingest(
        timestamp=start + timedelta(seconds=simulated_duration),
        channel_values={
            "test-channel": -6.66,
        },
    )

    # Test ingestion of a flow without all channels specified
    try:
        regular_flow.ingest(
            timestamp=start + timedelta(seconds=simulated_duration),
            channel_values={
                "test-channel": 0,
                "test-enum-channel": 2,
                # "test-bit-field-channel": bytes([0b01010101]),
            },
        )
    except ValueError as e:
        assert "Expected all channels in flow to have a data point at same time." in repr(e)

    # Test ingestion of a bad enum value (string and int)
    try:
        regular_flow.ingest(
            timestamp=start + timedelta(seconds=simulated_duration),
            channel_values={
                "test-channel": 0,
                "test-enum-channel": -3,
                "test-bit-field-channel": bytes([0b01010101]),
            },
        )
    except ValueError as e:
        assert "Could not find enum value: -3 in enum options: {'enum1': 1, 'enum2': 2}" in repr(e)
    try:
        regular_flow.ingest(
            timestamp=start + timedelta(seconds=simulated_duration),
            channel_values={
                "test-channel": 0,
                "test-enum-channel": "nonexistent-enum",
                "test-bit-field-channel": bytes([0b01010101]),
            },
        )
    except ValueError as e:
        assert (
            "Could not find enum value: nonexistent-enum in enum options: {'enum1': 1, 'enum2': 2}"
            in repr(e)
        )

    client.async_.ingestion.wait_for_ingestion_to_complete(timeout=2)
    end = datetime.now()
    # Test ingesting more data after letting a thread finish. Also exercise ingesting bitfield values as bytes.
    time.sleep(1)
    print("Restarting ingestion")
    regular_flow.ingest(
        timestamp=start + timedelta(seconds=simulated_duration + 1),
        channel_values={
            "test-channel": 7.77,
            "test-enum-channel": 1,
            "test-bit-field-channel": bytes([0b11111111]),
        },
    )
    # Wait less time than threads nominal no_data_timeout so we can exercise forced cleanup.
    client.async_.ingestion.wait_for_ingestion_to_complete(timeout=0.01)
    client.runs.archive(run=run.id_)

    # One high-speed sample per channel per tick, plus one regular sample per channel per second.
    highspeed_datapoints = fake_hs_rate * len(highspeed_flow.channels) * simulated_duration
    regular_datapoints = simulated_duration * len(regular_flow.channels)
    num_datapoints = highspeed_datapoints + regular_datapoints
    # Report the elapsed time as a number of seconds; formatting the raw timedelta would
    # print "H:MM:SS.ffffff" next to the word "seconds".
    total_time = (end - start).total_seconds()
    print(f"Ingestion time: {total_time:.2f} seconds")
    print(f"Ingested {num_datapoints} datapoints")
    print(f"Ingestion rate: {num_datapoints / total_time:.2f} datapoints/second")
+""" + + +def main(): + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + api_key = os.getenv("SIFT_API_KEY", "") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url) + + asset = client.assets.find(name="NostromoLV426") + asset_id = asset.id_ + print(f"Using asset: {asset.name} (ID: {asset_id})") + + unique_name_suffix = datetime.now().strftime("%Y%m%d%H%M%S") + num_rules = 8 + print(f"\n=== Creating {num_rules} rules with unique suffix: {unique_name_suffix} ===") + created_rules = [] + for i in range(num_rules): + rule = client.rules.create( + name=f"test_rule_{unique_name_suffix}_{i}", + description=f"Test rule {i} - initial description", + expression="$1 > 0.1", # Simple threshold check + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), + ], + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.DATA_REVIEW, + tags=["test", "initial"], + default_assignee_user_id=None, + ), + asset_ids=[asset_id], + ) + created_rules.append(rule) + print(f"Created rule: {rule.name} (ID: {rule.id_})") + + # Find the rules we just created + search_results = client.rules.list( + name_regex=f"test_rule_{unique_name_suffix}.*", + ) + assert len(search_results) == num_rules, ( + f"Expected {num_rules} created rules, got {len(search_results)}" + ) + + print("\n=== Testing comprehensive update scenarios ===") + + # Test 1: Update expression and channel references together + print("\n--- Test 1: Update expression and channel references ---") + rule_1 = created_rules[0] + rule_1_model_dump = rule_1.model_dump() + updated_rule_1 = rule_1.update( + RuleUpdate( + expression="$1 > 0.5", # Higher threshold + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), + ], + ) + ) + updated_rule_1_model_dump = updated_rule_1.model_dump() + print(f"Updated {updated_rule_1.name}: expression = 
{updated_rule_1.expression}") + + # Test 2: Update description + print("\n--- Test 2: Update description ---") + rule_2 = created_rules[1] + updated_rule_2 = rule_2.update( + RuleUpdate( + description="Updated description with more details about velocity-to-voltage ratio monitoring", + ) + ) + print(f"Updated {updated_rule_2.name}: description = {updated_rule_2.description}") + + # Test 3: Update action (change annotation type and tags) + print("\n--- Test 3: Update action ---") + rule_3 = created_rules[2] + updated_rule_3 = rule_3.update( + RuleUpdate( + action=RuleAction.annotation( + annotation_type=RuleAnnotationType.PHASE, + tags=["updated", "phase", "alert"], + default_assignee_user_id=rule_3.created_by_user_id, + ), + ) + ) + print(f"Updated {updated_rule_3.name}: action type = {updated_rule_3.action.action_type}") + print(f" - annotation type: {updated_rule_3.action.annotation_type}") + print(f" - tags: {updated_rule_3.action.tags}") + print(f" - assignee: {updated_rule_3.action.default_assignee_user_id}") + + # Test 4: Update name + print("\n--- Test 4: Update name ---") + rule_4 = created_rules[3] + new_name = f"renamed_rule_{unique_name_suffix}_4" + updated_rule_4 = rule_4.update( + RuleUpdate( + name=new_name, + ) + ) + print(f"Updated {rule_4.name} -> {updated_rule_4.name}") + + # Test 5: Update multiple fields at once + print("\n--- Test 5: Update multiple fields simultaneously ---") + rule_5 = created_rules[4] + updated_rule_5 = rule_5.update( + RuleUpdate( + description="Multi-field update test", + ), + version_notes="Updated via multi-field update", + ) + print(f"Updated {updated_rule_5.name}:") + print(f" - description: {updated_rule_5.description}") + print( + f" - version_notes: {updated_rule_5.rule_version.version_notes if updated_rule_5.rule_version else None}" + ) + + # Test 6: Update with complex expression + print("\n--- Test 6: Update with complex expression ---") + rule_6 = created_rules[5] + updated_rule_6 = rule_6.update( + RuleUpdate( 
+ expression="$1 > 0.3 && $1 < 0.8", # Range check + channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"), + ], + ) + ) + print(f"Updated {updated_rule_6.name}: complex expression = {updated_rule_6.expression}") + + # Test 7: Update action to notification type + print("\n--- Test 7: Update action to notification ---") + rule_7 = created_rules[6] + updated_rule_7 = rule_7 + # Note: Notification actions are not supported yet. + # updated_rule_7 = rule_7.update( + # RuleUpdate( + # action=RuleAction.notification( + # notify_recipients=[rule_7.created_by_user_id] + # ), + # ) + # ) + # print(f"Updated {updated_rule_7.name}: action type = {updated_rule_7.action.action_type}") + # print(f" - notification recipients: {updated_rule_7.action.notification_recipients}") + + # Test 8: Update tag_ids and contextual_channels + print("\n--- Test 8: Update tag_ids and contextual_channels ---") + rule_8 = created_rules[7] + updated_rule_8 = rule_8.update( + RuleUpdate( + # tag_ids=["tag-123", "tag-456"], # Example tag IDs # TODO: Where are these IDs supposed to come from? They're supposed to be uuids? 
{grpc_message:"invalid argument: invalid input syntax for type uuid: \"tag-123\" + contextual_channels=["temperature", "pressure"], # Example contextual channels + ) + ) + print(f"Updated {updated_rule_8.name}:") + print(f" - asset_tag_ids: {updated_rule_8.asset_tag_ids}") + print(f" - contextual_channels: {updated_rule_8.contextual_channels}") + + # Test 8b: Edge case - Update with invalid expression (should fail gracefully) + print("\n--- Test 8b: Edge case - Invalid expression test ---") + try: + invalid_update = rule_8.update( + RuleUpdate( + expression="invalid_expression", + channel_references=[ + ChannelReference( + channel_reference="$1", channel_identifier="mainmotor.velocity" + ), + ], + ) + ) + print(f"Invalid expression update succeeded (unexpected): {invalid_update.expression}") + except Exception as e: + print(f"Invalid expression update failed as expected: {e}") + + # Test 9: Batch operations demonstration + print("\n--- Test 9: Batch operations demonstration ---") + all_updated_rules = [ + updated_rule_1, + updated_rule_2, + updated_rule_3, + updated_rule_4, + updated_rule_5, + updated_rule_6, + updated_rule_7, + updated_rule_8, + ] + + # Batch get the updated rules + rule_ids = [rule.id_ for rule in all_updated_rules] + batch_rules = client.rules.batch_get(rule_ids=rule_ids) + print(f"Batch retrieved {len(batch_rules)} rules:") + for rule in batch_rules: + print(f" - {rule.name}: {rule.expression}") + + # Test 10: Archive rules + print("\n--- Test 10: Archive rules ---") + client.rules.archive(rules=created_rules) + + print("\n=== Test Summary ===") + print(f"Created: {len(created_rules)} rules") + print(f"Updated: {len(all_updated_rules)} rules") + + # Verify all rules were processed + assert len(created_rules) == num_rules, ( + f"Expected {num_rules} created rules, got {len(created_rules)}" + ) + assert len(all_updated_rules) == num_rules, ( + f"Expected {num_rules} updated rules, got {len(all_updated_rules)}" + ) + + # Additional validation + 
print("\n=== Validation Checks ===") + + # Verify that updates actually changed the values + assert updated_rule_1.expression == "$1 > 0.5", ( + f"Expression update failed: {updated_rule_1.expression}" + ) + # For update 1, also verify that the fields that were not updated are not reset. + assert updated_rule_1_model_dump["description"] == rule_1_model_dump["description"], ( + f"Expected no description change, got {rule_1_model_dump['description']} -> {updated_rule_1.description}" + ) + assert ( + updated_rule_1_model_dump["channel_references"] == rule_1_model_dump["channel_references"] + ), ( + f"Expected no channel references change, got {rule_1_model_dump['channel_references']} -> {updated_rule_1.channel_references}" + ) + assert updated_rule_1_model_dump["asset_ids"] == rule_1_model_dump["asset_ids"], ( + f"Expected no asset IDs change, got {rule_1_model_dump['asset_ids']} -> {updated_rule_1.asset_ids}" + ) + assert updated_rule_1_model_dump["asset_tag_ids"] == rule_1_model_dump["asset_tag_ids"], ( + f"Expected no tag IDs change, got {rule_1_model_dump['asset_tag_ids']} -> {updated_rule_1.asset_tag_ids}" + ) + assert ( + updated_rule_1_model_dump["contextual_channels"] == rule_1_model_dump["contextual_channels"] + ), f"Contextual channels update failed: {updated_rule_1.contextual_channels}" + assert "more details" in updated_rule_2.description, ( + f"Description update failed: {updated_rule_2.description}" + ) + assert updated_rule_3.action.annotation_type == RuleAnnotationType.PHASE, ( + f"Action update failed: {updated_rule_3.action.annotation_type}" + ) + assert updated_rule_4.name == new_name, f"Name update failed: {updated_rule_4.name}" + + assert updated_rule_6.expression == "$1 > 0.3 && $1 < 0.8", ( + f"Complex expression update failed: {updated_rule_6.expression}" + ) + # assert updated_rule_7.action.action_type == RuleActionType.NOTIFICATION, f"Action type update failed: {updated_rule_7.action.action_type}" + # assert len(updated_rule_8.tag_ids) == 2, 
f"Tag IDs update failed: {updated_rule_8.tag_ids}" + assert len(updated_rule_8.contextual_channels) == 2, ( + f"Contextual channels update failed: {updated_rule_8.contextual_channels}" + ) + + print("All validation checks passed!") + print("\n=== Test completed successfully ===") + + +if __name__ == "__main__": + main() diff --git a/python/lib/sift_client/_tests/integrated/runs.py b/python/lib/sift_client/_tests/integrated/runs.py index 1209107a8..a620f713c 100644 --- a/python/lib/sift_client/_tests/integrated/runs.py +++ b/python/lib/sift_client/_tests/integrated/runs.py @@ -32,23 +32,23 @@ async def main(): # Use a known asset to fetch a run. asset = client.assets.find(name="NostromoLV426") - asset_id = asset.id + asset_id = asset.id_ print(f"Using asset: {asset.name} (ID: {asset_id})") # List runs for this asset - runs = asset.runs() + runs = asset.runs print( f"Found {len(runs)} run(s): {[run.name for run in runs]} for asset {asset.name} (ID: {asset_id})" ) # Pick one. run = runs[0] - run_id = run.id + run_id = run.id_ print(f"Using run: {run.name} (ID: {run_id})") # List other assets for this run. - all_assets = run.assets() - other_assets = [asset for asset in all_assets if asset.id != asset_id] + all_assets = run.assets + other_assets = [asset for asset in all_assets if asset.id_ != asset_id] print( f"Found {len(other_assets)} other asset(s): {other_assets} for run {run.name} (ID: {run_id})" ) @@ -58,7 +58,7 @@ async def main(): runs = client.runs.list(limit=5) print(f" Found {len(runs)} runs:") for run in runs: - print(f" - {run.name} (ID: {run.id}), Organization ID: {run.organization_id}") + print(f" - {run.name} (ID: {run.id_}), Organization ID: {run.organization_id}") # Example 2: Test different filter options print("\n2. 
Testing different filter options...") @@ -77,7 +77,7 @@ async def main(): runs = client.runs.list(name=run_name, limit=5) print(f" Found {len(runs)} runs with exact name '{run_name}':") for run in runs: - print(f" - {run.name} (ID: {run.id})") + print(f" - {run.name} (ID: {run.id_})") # 2b: Filter by name containing text print("\n 2b. Filter by name containing text...") @@ -214,7 +214,7 @@ async def main(): print(f" Deleting previously created runs: {previously_created_runs}") for run in previously_created_runs: print(f" Deleting run: {run.name}") - client.runs.delete(run=run) + client.runs.archive(run=run) new_run = client.runs.create( name=f"Example Test Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", @@ -226,7 +226,7 @@ async def main(): client_key=f"example-run-key-{datetime.now().timestamp()}", metadata=metadata, ) - print(f" Created run: {new_run.name} (ID: {new_run.id})") + print(f" Created run: {new_run.name} (ID: {new_run.id_})") print(f" Client key: {new_run.client_key}") print(f" Tags: {new_run.tags}") @@ -281,8 +281,8 @@ async def main(): print("\n7. 
Deleting a run") run_to_delete = new_run print(f" Deleting run: {run_to_delete.name}") - client.runs.delete(run=run_to_delete) - print(f" Successfully deleted run: {run_to_delete.name}") + client.runs.archive(run=run_to_delete) + print(f" Successfully archived run: {run_to_delete.name}") if __name__ == "__main__": diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index c5623a5c9..141d20b91 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -6,8 +6,13 @@ AssetsAPIAsync, CalculatedChannelsAPI, CalculatedChannelsAPIAsync, + ChannelsAPI, + ChannelsAPIAsync, + IngestionAPIAsync, PingAPI, PingAPIAsync, + RulesAPI, + RulesAPIAsync, RunsAPI, RunsAPIAsync, ) @@ -70,6 +75,15 @@ class SiftClient( calculated_channels: CalculatedChannelsAPI """Instance of the Calculated Channels API for making synchronous requests.""" + channels: ChannelsAPI + """Instance of the Channels API for making synchronous requests.""" + + ingestion: IngestionAPIAsync + """Instance of the Ingestion API for making synchronous requests.""" + + rules: RulesAPI + """Instance of the Rules API for making synchronous requests.""" + runs: RunsAPI """Instance of the Runs API for making synchronous requests.""" @@ -115,12 +129,19 @@ def __init__( self.ping = PingAPI(self) self.assets = AssetsAPI(self) self.calculated_channels = CalculatedChannelsAPI(self) + self.channels = ChannelsAPI(self) + self.ingestion = IngestionAPIAsync(self) + self.rules = RulesAPI(self) self.runs = RunsAPI(self) + # Accessor for the asynchronous APIs self.async_ = AsyncAPIs( ping=PingAPIAsync(self), assets=AssetsAPIAsync(self), calculated_channels=CalculatedChannelsAPIAsync(self), + channels=ChannelsAPIAsync(self), + ingestion=IngestionAPIAsync(self), + rules=RulesAPIAsync(self), runs=RunsAPIAsync(self), ) diff --git a/python/lib/sift_client/examples/generic_workflow_example.py b/python/lib/sift_client/examples/generic_workflow_example.py new file mode 100644 index 
async def main():
    """Walk through a calculated-channel and rule workflow against a Sift deployment.

    Connection settings come from the ``SIFT_GRPC_URI``, ``SIFT_API_KEY`` and
    ``SIFT_REST_URI`` environment variables. The script then:

    1. Looks up the ``NostromoLV426`` asset.
    2. Updates the ``velocity_per_voltage`` calculated channel when it exists,
       otherwise creates it.
    3. Updates the first rule whose name contains ``high_velocity_voltage``, or
       creates ``high_velocity_voltage_ratio_alert`` when no such rule exists.
    4. Archives the rule again when step 3 took the update path (i.e. the rule
       was left over from an earlier run).
    """
    grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051")
    api_key = os.getenv("SIFT_API_KEY", "")
    rest_url = os.getenv("SIFT_REST_URI", "localhost:8080")
    client = SiftClient(grpc_url=grpc_url, api_key=api_key, rest_url=rest_url)

    asset = client.assets.find(name="NostromoLV426")
    asset_id = asset.id_
    print("Found asset", asset.name)

    calculated_channels = client.calculated_channels.list(
        name_regex="velocity_per.*",
        asset_id=asset_id,
    )
    updated = False
    calculated_channel = None
    if calculated_channels:
        print(f"Found calculated channels: {[cc.name for cc in calculated_channels]}")
        for cc in calculated_channels:
            if cc.name == "velocity_per_voltage":
                calculated_channel = cc.update(
                    CalculatedChannelUpdate(
                        expression="$1 / $2 + 0.1",
                        expression_channel_references=cc.channel_references,
                    )
                )
                print("Updated calculated channel", calculated_channel)

    if calculated_channel is None:
        # The regex above can match channels other than "velocity_per_voltage",
        # so a non-empty listing does not guarantee the channel exists. Create it
        # whenever the exact name was not found; the rule below needs its name.
        print("\nCreating calculated channel...")
        calculated_channel = client.calculated_channels.create(
            name="velocity_per_voltage",
            description="Ratio of mainmotor velocity to voltage",
            expression="$1 / $2",  # $1 = mainmotor.velocity, $2 = voltage
            channel_references=[
                ChannelReference(channel_reference="$1", channel_identifier="mainmotor.velocity"),
                ChannelReference(channel_reference="$2", channel_identifier="voltage"),
            ],
            units="velocity/voltage",
            asset_ids=[asset_id],
            user_notes="Created to monitor velocity-to-voltage ratio",
        )
        print(
            f"Created calculated channel: {calculated_channel.name} (ID: {calculated_channel.id_})"
        )

    # Create a rule that creates an annotation when the ratio is above 0.1
    rule_search = "high_velocity_voltage"
    print(f"Looking for rule containing {rule_search}")
    rules = client.rules.list(
        name_contains=rule_search,
    )
    if rules:
        print(f"Found rules: {[rule.name for rule in rules]}")
        # Example of batch get if you just had the rule ids:
        rules = client.rules.batch_get(rule_ids=[rule.id_ for rule in rules])
        print(f"Batch get on IDs also works: {[rule.name for rule in rules]}")

        rule = rules[0]
        print(f"Updating rule: {rule.name}")
        rule = rule.update(
            RuleUpdate(
                description=f"Alert when velocity-to-voltage ratio exceeds 0.1 (Updated at {datetime.now().isoformat()})",
                asset_ids=[asset_id],
            )
        )
        updated = True
    else:
        print(f"No rules found for {rule_search}")
        # The rules API exposes list/find/batch_get but no asset-scoped search,
        # so list everything and keep the rules attached to this asset.
        rules = [
            candidate
            for candidate in client.rules.list()
            if asset_id in (candidate.asset_ids or [])
        ]
        if rules:
            print(f"However these rules do exist: {[rule.name for rule in rules]}")
        print("Attempting to create rule for high_velocity_voltage_ratio_alert")
        rule = client.rules.create(
            name="high_velocity_voltage_ratio_alert",
            description="Alert when velocity-to-voltage ratio exceeds 0.1",
            expression="$1 > 0.1",
            channel_references=[
                ChannelReference(
                    channel_reference="$1", channel_identifier=calculated_channel.name
                ),
            ],
            action=RuleAction.annotation(
                annotation_type=RuleAnnotationType.DATA_REVIEW,
                tags=["high_ratio", "alert"],
                default_assignee_user_id=None,  # You can set a user ID here if needed
            ),
        )
        print(f"Created rule: {rule.name} (ID: {rule.id_})")

    if updated:
        # `updated` is only set on the branch that bound `rule` from an existing
        # rule, so `rule` is always defined here.
        print("Second run through, deleting rule")
        client.rules.archive(rule=rule)
100644 --- a/python/lib/sift_client/resources/__init__.py +++ b/python/lib/sift_client/resources/__init__.py @@ -1,21 +1,31 @@ from sift_client.resources.assets import AssetsAPIAsync from sift_client.resources.calculated_channels import CalculatedChannelsAPIAsync +from sift_client.resources.channels import ChannelsAPIAsync +from sift_client.resources.ingestion import IngestionAPIAsync from sift_client.resources.ping import PingAPIAsync +from sift_client.resources.rules import RulesAPIAsync from sift_client.resources.runs import RunsAPIAsync from sift_client.resources.sync_stubs import ( AssetsAPI, CalculatedChannelsAPI, + ChannelsAPI, PingAPI, + RulesAPI, RunsAPI, ) __all__ = [ "AssetsAPIAsync", "CalculatedChannelsAPIAsync", + "ChannelsAPIAsync", + "IngestionAPIAsync", "PingAPIAsync", + "RulesAPIAsync", "RunsAPIAsync", "AssetsAPI", "CalculatedChannelsAPI", + "ChannelsAPI", "PingAPI", + "RulesAPI", "RunsAPI", ] diff --git a/python/lib/sift_client/resources/assets.py b/python/lib/sift_client/resources/assets.py index 4a69c8a6c..dc7a35235 100644 --- a/python/lib/sift_client/resources/assets.py +++ b/python/lib/sift_client/resources/assets.py @@ -84,6 +84,7 @@ async def list_( created_by: Any | None = None, modified_by: Any | None = None, tags: list[str] | None = None, + tag_ids: list[str] | None = None, metadata: list[Any] | None = None, include_archived: bool = False, filter_query: str | None = None, @@ -94,6 +95,7 @@ async def list_( List assets with optional filtering. Args: + asset_ids: List of asset IDs to filter by. name: Exact name of the asset. name_contains: Partial name of the asset. name_regex: Regular expression string to filter assets by name. @@ -105,6 +107,7 @@ async def list_( created_by: Assets created by this user. modified_by: Assets last modified by this user. tags: Assets with these tags. + tag_ids: List of asset tag IDs to filter by. include_archived: Include archived assets. filter_query: Explicit CEL query to filter assets. 
order_by: How to order the retrieved assets. # TODO: tooling for this? @@ -137,7 +140,9 @@ async def list_( if modified_by: raise NotImplementedError if tags: - raise NotImplementedError + filters.append(cel_utils.in_("tag_name", tags)) + if tag_ids: + filters.append(cel_utils.in_("tag_ids", tag_ids)) if metadata: raise NotImplementedError if not include_archived: @@ -180,9 +185,9 @@ async def archive(self, asset: str | Asset, *, archive_runs: bool = False) -> As Returns: The archived Asset. """ - asset_id = asset.id if isinstance(asset, Asset) else asset + asset_id = asset.id_ or "" if isinstance(asset, Asset) else asset - await self._low_level_client.delete_asset(asset_id, archive_runs=archive_runs) + await self._low_level_client.delete_asset(asset_id or "", archive_runs=archive_runs) return await self.get(asset_id=asset_id) @@ -198,7 +203,7 @@ async def update(self, asset: str | Asset, update: AssetUpdate | dict) -> Asset: The updated Asset. """ - asset_id = asset.id if isinstance(asset, Asset) else asset + asset_id = asset.id_ or "" if isinstance(asset, Asset) else asset if isinstance(update, dict): update = AssetUpdate.model_validate(update) update.resource_id = asset_id diff --git a/python/lib/sift_client/resources/calculated_channels.py b/python/lib/sift_client/resources/calculated_channels.py index 474bdc762..b248936d3 100644 --- a/python/lib/sift_client/resources/calculated_channels.py +++ b/python/lib/sift_client/resources/calculated_channels.py @@ -269,7 +269,7 @@ async def update( The updated CalculatedChannel. 
class ChannelsAPIAsync(ResourceBase):
    """
    High-level API for interacting with channels.

    This class provides a Pythonic, notebook-friendly interface for interacting with the ChannelsAPI.
    It handles automatic handling of gRPC services, seamless type conversion, and clear error handling.

    All methods in this class use the Channel class from the low-level wrapper, which is a user-friendly
    representation of a channel using standard Python data structures and types.
    """

    def __init__(self, sift_client: "SiftClient"):
        """
        Initialize the ChannelsAPI.

        Builds one low-level client for channel metadata and one for channel data,
        both on top of the Sift client's gRPC client.

        Args:
            sift_client: The Sift client to use.
        """
        super().__init__(sift_client)
        self._low_level_client = ChannelsLowLevelClient(grpc_client=self.client.grpc_client)
        self._data_low_level_client = DataLowLevelClient(grpc_client=self.client.grpc_client)

    async def get(
        self,
        *,
        channel_id: str,
    ) -> Channel:
        """
        Get a Channel.

        Args:
            channel_id: The ID of the channel.

        Returns:
            The Channel.
        """
        channel = await self._low_level_client.get_channel(channel_id=channel_id)
        return self._apply_client_to_instance(channel)

    async def list(
        self,
        *,
        asset_id: str | None = None,
        name: str | None = None,
        name_contains: str | None = None,
        name_regex: str | re.Pattern | None = None,
        description: str | None = None,
        description_contains: str | None = None,
        active: bool | None = None,
        run_id: str | None = None,
        run_name: str | None = None,
        client_key: str | None = None,
        created_before: datetime | None = None,
        created_after: datetime | None = None,
        modified_before: datetime | None = None,
        modified_after: datetime | None = None,
        order_by: str | None = None,
        limit: int | None = None,
    ) -> list[Channel]:
        """
        List channels with optional filtering.

        Every filter that is supplied is combined with a logical AND.

        Args:
            asset_id: Only channels belonging to this asset ID.
            name: Exact name of the channel.
            name_contains: Partial name of the channel.
            name_regex: Regular expression (string or compiled pattern) to match the name against.
            description: Exact description of the channel.
            description_contains: Partial description of the channel.
            active: Filter on the channel's active flag. ``None`` (the default) applies no filter;
                both ``True`` and ``False`` are applied as filters.
            run_id: Only channels for this run ID.
            run_name: Only channels for the run with this name.
            client_key: Only channels for the run with this client key.
            created_before: Only channels created before this time.
            created_after: Only channels created after this time.
            modified_before: Only channels modified before this time.
            modified_after: Only channels modified after this time.
            order_by: How to order the retrieved channels.
            limit: How many channels to retrieve. If None, retrieves all matches.

        Returns:
            A list of Channels that matches the filter.

        Raises:
            ValueError: If more than one of name/name_contains/name_regex is given, or both
                members of the description, created or modified pairs are given.
        """
        if sum(bool(x) for x in [name, name_contains, name_regex]) > 1:
            raise ValueError("Cannot provide more than one of name, name_contains, or name_regex")
        if sum(bool(x) for x in [description, description_contains]) > 1:
            raise ValueError("Cannot provide both description and description_contains")
        if sum(bool(x) for x in [created_before, created_after]) > 1:
            raise ValueError("Cannot provide both created_before and created_after")
        if sum(bool(x) for x in [modified_before, modified_after]) > 1:
            raise ValueError("Cannot provide both modified_before and modified_after")

        filter_parts = []
        if asset_id:
            filter_parts.append(cel.equals("asset_id", asset_id))
        if name:
            filter_parts.append(cel.equals("name", name))
        elif name_contains:
            filter_parts.append(cel.contains("name", name_contains))
        elif name_regex:
            # The CEL helper is given the pattern text, so unwrap compiled patterns.
            if isinstance(name_regex, re.Pattern):
                name_regex = name_regex.pattern
            filter_parts.append(cel.match("name", name_regex))  # type: ignore
        if description:
            filter_parts.append(cel.equals("description", description))
        elif description_contains:
            filter_parts.append(cel.contains("description", description_contains))
        # Compare against None rather than truthiness: `active=False` is a real
        # filter request ("inactive channels only") and must not be dropped.
        if active is not None:
            filter_parts.append(cel.equals("active", active))
        if run_id:
            filter_parts.append(cel.equals("run_id", run_id))
        if run_name:
            filter_parts.append(cel.equals("run_name", run_name))
        if client_key:
            filter_parts.append(cel.equals("client_key", client_key))
        if created_before:
            filter_parts.append(cel.less_than("created_date", created_before))
        if created_after:
            filter_parts.append(cel.greater_than("created_date", created_after))
        if modified_before:
            filter_parts.append(cel.less_than("modified_date", modified_before))
        if modified_after:
            filter_parts.append(cel.greater_than("modified_date", modified_after))

        # With no filters this is the empty string.
        filter_str = " && ".join(filter_parts)

        channels = await self._low_level_client.list_all_channels(
            query_filter=filter_str,
            order_by=order_by,
            max_results=limit,
        )
        return self._apply_client_to_instances(channels)

    async def find(self, **kwargs) -> Channel | None:
        """
        Find a single channel matching the given query. Takes the same arguments as `list`. If more than one channel is found,
        raises an error.

        Args:
            **kwargs: Keyword arguments to pass to `list`.

        Returns:
            The Channel found or None.

        Raises:
            ValueError: If the query matches more than one channel.
        """
        channels = await self.list(**kwargs)
        if len(channels) > 1:
            raise ValueError("Multiple channels found for query")
        elif len(channels) == 1:
            return channels[0]
        return None

    async def get_data(
        self,
        *,
        channels: List[Channel],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int | None = None,
    ) -> Dict[str, np.ndarray]:
        """
        Get data for one or more channels.

        Args:
            channels: The channels to get data for.
            run_id: The run to get data for.
            start_time: The start time to get data for.
            end_time: The end time to get data for.
            limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided.

        Returns:
            Whatever the data low-level client returns for the request, as a dict with
            string keys (presumably channel names -- confirm against the low-level client).
        """
        # NOTE(review): the return annotation says np.ndarray, but get_data_as_arrow
        # feeds these values to pa.Table.from_pandas -- confirm the actual value type.
        return await self._data_low_level_client.get_channel_data(
            channels=channels,
            run_id=run_id,
            start_time=start_time,
            end_time=end_time,
            limit=limit,
        )

    async def get_data_as_arrow(
        self,
        *,
        channels: List[Channel],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int | None = None,
    ) -> Dict[str, pa.Table]:
        """
        Get data for one or more channels as pyarrow tables.

        Takes the same arguments as `get_data` and converts each value of its result.

        Args:
            channels: The channels to get data for.
            run_id: The run to get data for.
            start_time: The start time to get data for.
            end_time: The end time to get data for.
            limit: The maximum number of data points to return.

        Returns:
            A dict with the same keys as `get_data`, each value converted to a pyarrow Table.
        """
        data = await self.get_data(
            channels=channels,
            run_id=run_id,
            start_time=start_time,
            end_time=end_time,
            limit=limit,
        )
        # NOTE(review): from_pandas expects pandas objects; see the note in get_data
        # about the np.ndarray annotation.
        return {k: pa.Table.from_pandas(v) for k, v in data.items()}
def ingest(
    self,
    *,
    flow: Flow,
    timestamp: datetime,
    channel_values: dict[str, Any],
):
    """
    Forward one timestamped set of channel values for a flow to the low-level client.

    Args:
        flow: The flow the values belong to.
        timestamp: The timestamp passed along with the values.
        channel_values: The values to ingest, keyed by string
            (presumably channel names -- confirm against the low-level client).
    """
    send_to_low_level = self._low_level_client.ingest_flow
    send_to_low_level(flow=flow, timestamp=timestamp, channel_values=channel_values)
class RulesAPIAsync(ResourceBase):
    """
    High-level API for interacting with rules.

    This class provides a Pythonic, notebook-friendly interface for interacting with the RulesAPI.
    It handles automatic handling of gRPC services, seamless type conversion, and clear error handling.

    All methods in this class use the Rule class from the low-level wrapper, which is a user-friendly
    representation of a rule using standard Python data structures and types.
    """

    def __init__(self, sift_client: "SiftClient"):
        """
        Initialize the RulesAPI.

        Args:
            sift_client: The Sift client to use.
        """
        super().__init__(sift_client)
        self._low_level_client = RulesLowLevelClient(grpc_client=self.client.grpc_client)

    async def get(
        self,
        *,
        rule_id: str | None = None,
        client_key: str | None = None,
    ) -> Rule:
        """
        Get a Rule.

        Args:
            rule_id: The ID of the rule.
            client_key: The client key of the rule.

        Returns:
            The Rule.
        """
        rule = await self._low_level_client.get_rule(rule_id=rule_id, client_key=client_key)
        return self._apply_client_to_instance(rule)

    async def list(
        self,
        *,
        name: str | None = None,
        name_contains: str | None = None,
        name_regex: str | re.Pattern | None = None,
        order_by: str | None = None,
        limit: int | None = None,
        include_deleted: bool = False,
    ) -> list[Rule]:
        """
        List rules with optional filtering.

        Args:
            name: Exact name of the rule.
            name_contains: Partial name of the rule.
            name_regex: Regular expression (string or compiled pattern) to filter rules by name.
            order_by: How to order the retrieved rules.
            limit: How many rules to retrieve. If None, retrieves all matches.
            include_deleted: When False (the default), only rules whose ``deleted_date``
                is null are returned.

        Returns:
            A list of Rules that matches the filter.

        Raises:
            ValueError: If more than one of name, name_contains and name_regex is given.
        """
        name_filters = (name, name_contains, name_regex)
        if sum(value is not None for value in name_filters) > 1:
            raise ValueError("Must use EITHER name, name_contains, or name_regex, not multiple")

        filters = []
        if name:
            filters.append(cel.equals("name", name))
        if name_contains:
            filters.append(cel.contains("name", name_contains))
        if name_regex:
            # The signature accepts compiled patterns; hand the CEL helper the
            # pattern text, as the channels API does.
            pattern = name_regex.pattern if isinstance(name_regex, re.Pattern) else name_regex
            filters.append(cel.match("name", pattern))
        if not include_deleted:
            filters.append(cel.equals_null("deleted_date"))
        # With no filters this is the empty string.
        filter_str = " && ".join(filters)
        rules = await self._low_level_client.list_all_rules(
            filter_query=filter_str,
            order_by=order_by,
            max_results=limit,
            page_size=limit,
        )
        return self._apply_client_to_instances(rules)

    async def find(self, **kwargs) -> Rule | None:
        """
        Find a single rule matching the given query. Takes the same arguments as `list`. If more than one rule is found,
        raises an error.

        Args:
            **kwargs: Keyword arguments to pass to `list`.

        Returns:
            The Rule found or None.

        Raises:
            ValueError: If the query matches more than one rule.
        """
        rules = await self.list(**kwargs)
        if len(rules) > 1:
            raise ValueError("Multiple rules found for query")
        elif len(rules) == 1:
            return rules[0]
        return None

    async def create(
        self,
        name: str,
        description: str,
        expression: str,
        channel_references: List[ChannelReference],
        action: RuleAction,
        organization_id: str | None = None,
        client_key: str | None = None,
        asset_ids: List[str] | None = None,
        contextual_channels: List[str] | None = None,
        is_external: bool = False,
    ) -> Rule:
        """
        Create a new rule.

        All arguments are forwarded unchanged to the low-level client.

        Args:
            name: Name of the rule.
            description: Description of the rule.
            expression: The rule expression, e.g. ``"$1 > 0.1"``.
            channel_references: Channel references used by the expression.
            action: The action attached to the rule.
            organization_id: Optional organization ID.
            client_key: Optional client key for the rule.
            asset_ids: Optional asset IDs to associate with the rule.
            contextual_channels: Optional contextual channels for the rule.
            is_external: Forwarded as the rule's ``is_external`` flag.

        Returns:
            The created Rule.
        """
        created_rule = await self._low_level_client.create_rule(
            name=name,
            description=description,
            organization_id=organization_id,
            expression=expression,
            action=action,
            channel_references=channel_references,
            client_key=client_key,
            asset_ids=asset_ids,
            contextual_channels=contextual_channels,
            is_external=is_external,
        )
        return self._apply_client_to_instance(created_rule)

    async def update(
        self, rule: str | Rule, update: RuleUpdate | dict, version_notes: str | None = None
    ) -> Rule:
        """
        Update a Rule.

        Args:
            rule: The Rule or rule ID to update. An ID is resolved to a Rule first.
            update: Updates to apply to the Rule. A dict is validated into a RuleUpdate.
            version_notes: Notes to include in the rule version.

        Returns:
            The updated Rule.
        """
        if isinstance(rule, str):
            rule = await self.get(rule_id=rule)

        if isinstance(update, dict):
            update = RuleUpdate.model_validate(update)

        updated_rule = await self._low_level_client.update_rule(rule, update, version_notes)
        return self._apply_client_to_instance(updated_rule)

    async def archive(
        self,
        *,
        rule: str | Rule | None = None,
        rules: List[Rule] | None = None,
        rule_ids: List[str] | None = None,
        client_keys: List[str] | None = None,
    ) -> None:
        """
        Archive a rule or multiple.

        Only the first non-empty argument, in the order listed below, is used.

        Args:
            rule: The Rule or rule ID to archive.
            rules: The Rules to archive.
            rule_ids: The rule IDs to archive.
            client_keys: The client keys to archive.

        Raises:
            ValueError: If none of the arguments is provided.
        """
        if rule:
            single_id = rule.id_ if isinstance(rule, Rule) else rule
            await self._low_level_client.archive_rule(rule_id=single_id)
        elif rules:
            # A single rule goes through the single-rule call; more use the batch call.
            if len(rules) == 1:
                await self._low_level_client.archive_rule(rule_id=rules[0].id_)
            else:
                await self._low_level_client.batch_archive_rules(
                    rule_ids=[r.id_ for r in rules],  # type: ignore
                )
        elif rule_ids:
            if len(rule_ids) == 1:
                await self._low_level_client.archive_rule(rule_id=rule_ids[0])
            else:
                await self._low_level_client.batch_archive_rules(rule_ids=rule_ids)
        elif client_keys:
            await self._low_level_client.batch_archive_rules(client_keys=client_keys)
        else:
            raise ValueError("One of rule, rules, rule_ids, or client_keys must be provided")

    async def restore(
        self,
        *,
        rule: str | Rule | None = None,
        rule_id: str | None = None,
        client_key: str | None = None,
    ) -> Rule:
        """
        Restore a rule.

        ``rule_id`` / ``client_key`` take precedence over ``rule`` when given.

        Args:
            rule: The Rule or rule ID to restore.
            rule_id: The rule ID to restore (alternative to rule parameter).
            client_key: The client key to restore (alternative to rule parameter).

        Returns:
            The restored Rule.

        Raises:
            ValueError: If none of rule, rule_id and client_key is provided.
        """
        if rule_id or client_key:
            restored_rule = await self._low_level_client.restore_rule(
                rule_id=rule_id, client_key=client_key
            )
        else:
            # `rule` is optional so callers can restore by rule_id/client_key alone;
            # reject the call only when nothing identifies the rule.
            if rule is None:
                raise ValueError("One of rule, rule_id, or client_key must be provided")
            rule_id = rule.id_ if isinstance(rule, Rule) else rule
            restored_rule = await self._low_level_client.restore_rule(rule_id=rule_id)

        return self._apply_client_to_instance(restored_rule)

    async def batch_restore(
        self,
        *,
        rule_ids: List[str] | None = None,
        client_keys: List[str] | None = None,
    ) -> None:
        """
        Batch restore rules.

        Args:
            rule_ids: List of rule IDs to restore.
            client_keys: List of client keys to restore.
        """
        await self._low_level_client.batch_restore_rules(rule_ids=rule_ids, client_keys=client_keys)

    async def batch_get(
        self,
        *,
        rule_ids: List[str] | None = None,
        client_keys: List[str] | None = None,
    ) -> List[Rule]:
        """
        Get multiple rules by rule IDs or client keys.

        Args:
            rule_ids: List of rule IDs to get.
            client_keys: List of client keys to get.

        Returns:
            List of Rules.
        """
        rules = await self._low_level_client.batch_get_rules(
            rule_ids=rule_ids, client_keys=client_keys
        )
        return self._apply_client_to_instances(rules)
# Synchronous wrappers generated from the async resource APIs.
PingAPI = generate_sync_api(PingAPIAsync, "PingAPI")
AssetsAPI = generate_sync_api(AssetsAPIAsync, "AssetsAPI")
CalculatedChannelsAPI = generate_sync_api(CalculatedChannelsAPIAsync, "CalculatedChannelsAPI")
ChannelsAPI = generate_sync_api(ChannelsAPIAsync, "ChannelsAPI")
RulesAPI = generate_sync_api(RulesAPIAsync, "RulesAPI")
RunsAPI = generate_sync_api(RunsAPIAsync, "RunsAPI")

# Every generated wrapper is exported; ChannelsAPI and RulesAPI were previously
# created here but left out of the public name list.
__all__ = [
    "PingAPI",
    "AssetsAPI",
    "CalculatedChannelsAPI",
    "ChannelsAPI",
    "RulesAPI",
    "RunsAPI",
]
b/python/lib/sift_client/resources/sync_stubs/__init__.pyi index f978b2b2b..ca9c37336 100644 --- a/python/lib/sift_client/resources/sync_stubs/__init__.pyi +++ b/python/lib/sift_client/resources/sync_stubs/__init__.pyi @@ -4,12 +4,16 @@ from __future__ import annotations import re from datetime import datetime -from typing import Any, List +from typing import Any, Dict, List + +import numpy as np +import pyarrow as pa from sift_client.client import SiftClient from sift_client.types.asset import Asset, AssetUpdate from sift_client.types.calculated_channel import CalculatedChannel, CalculatedChannelUpdate -from sift_client.types.channel import ChannelReference +from sift_client.types.channel import Channel, ChannelReference +from sift_client.types.rule import Rule, RuleAction, RuleUpdate from sift_client.types.run import Run, RunUpdate class AssetsAPI: @@ -92,6 +96,7 @@ class AssetsAPI: created_by: Any | None = None, modified_by: Any | None = None, tags: list[str] | None = None, + tag_ids: list[str] | None = None, metadata: list[Any] | None = None, include_archived: bool = False, filter_query: str | None = None, @@ -102,6 +107,7 @@ class AssetsAPI: List assets with optional filtering. Args: + asset_ids: List of asset IDs to filter by. name: Exact name of the asset. name_contains: Partial name of the asset. name_regex: Regular expression string to filter assets by name. @@ -113,6 +119,7 @@ class AssetsAPI: created_by: Assets created by this user. modified_by: Assets last modified by this user. tags: Assets with these tags. + tag_ids: List of asset tag IDs to filter by. include_archived: Include archived assets. filter_query: Explicit CEL query to filter assets. order_by: How to order the retrieved assets. # TODO: tooling for this? @@ -361,6 +368,139 @@ class CalculatedChannelsAPI: """ ... +class ChannelsAPI: + """ + Sync counterpart to `ChannelsAPIAsync`. + + + High-level API for interacting with channels. 
+ + This class provides a Pythonic, notebook-friendly interface for interacting with the ChannelsAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. + + All methods in this class use the Channel class from the low-level wrapper, which is a user-friendly + representation of a channel using standard Python data structures and types. + """ + + def __init__(self, sift_client: "SiftClient"): + """ + Initialize the ChannelsAPI. + + Args: + sift_client: The Sift client to use. + """ + ... + + def _run(self, coro): + """ """ + ... + + def find(self, **kwargs) -> Channel | None: + """ + Find a single channel matching the given query. Takes the same arguments as `list`. If more than one channel is found, + raises an error. + + Args: + **kwargs: Keyword arguments to pass to `list`. + + Returns: + The Channel found or None. + """ + ... + + def get(self, *, channel_id: str) -> Channel: + """ + Get a Channel. + + Args: + channel_id: The ID of the channel. + + Returns: + The Channel. + """ + ... + + def get_data( + self, + *, + channels: List[Channel], + run_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + limit: int | None = None, + ) -> Dict[str, np.ndarray]: + """ + Get data for one or more channels. + + Args: + channels: The channels to get data for. + run_id: The run to get data for. + start_time: The start time to get data for. + end_time: The end time to get data for. + limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided. + """ + ... + + def get_data_as_arrow( + self, + *, + channels: List[Channel], + run_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + limit: int | None = None, + ) -> Dict[str, pa.Table]: + """ + Get data for one or more channels as pyarrow tables. + """ + ... 
+ + def list( + self, + *, + asset_id: str | None = None, + name: str | None = None, + name_contains: str | None = None, + name_regex: str | re.Pattern | None = None, + description: str | None = None, + description_contains: str | None = None, + active: bool | None = None, + run_id: str | None = None, + run_name: str | None = None, + client_key: str | None = None, + created_before: datetime | None = None, + created_after: datetime | None = None, + modified_before: datetime | None = None, + modified_after: datetime | None = None, + order_by: str | None = None, + limit: int | None = None, + ) -> list[Channel]: + """ + List channels with optional filtering. + + Args: + asset_id: The asset ID to get. + name: The name of the channel to get. + name_contains: The partial name of the channel to get. + name_regex: The regex name of the channel to get. + description: The description of the channel to get. + description_contains: The partial description of the channel to get. + active: Whether the channel is active. + run_id: The run ID to get. + run_name: The name of the run to get. + client_key: The client key of the run to get. + created_before: The created date of the channel to get. + created_after: The created date of the channel to get. + modified_before: The modified date of the channel to get. + modified_after: The modified date of the channel to get. + order_by: How to order the retrieved channels. + limit: How many channels to retrieve. If None, retrieves all matches. + + Returns: + A list of Channels that matches the filter. + """ + ... + class PingAPI: """ Sync counterpart to `PingAPIAsync`. @@ -391,6 +531,179 @@ class PingAPI: """ ... +class RulesAPI: + """ + Sync counterpart to `RulesAPIAsync`. + + + High-level API for interacting with rules. + + This class provides a Pythonic, notebook-friendly interface for interacting with the RulesAPI. + It handles automatic handling of gRPC services, seamless type conversion, and clear error handling. 
+ + All methods in this class use the Rule class from the low-level wrapper, which is a user-friendly + representation of a rule using standard Python data structures and types. + """ + + def __init__(self, sift_client: "SiftClient"): + """ + Initialize the RulesAPI. + + Args: + sift_client: The Sift client to use. + """ + ... + + def _run(self, coro): + """ """ + ... + + def archive( + self, + *, + rule: str | Rule | None = None, + rules: List[Rule] | None = None, + rule_ids: List[str] | None = None, + client_keys: List[str] | None = None, + ) -> None: + """ + Archive a rule or multiple. + + Args: + rule: The Rule to archive. + rules: The Rules to archive. + rule_ids: The rule IDs to archive. + client_keys: The client keys to archive. + """ + ... + + def batch_get( + self, *, rule_ids: List[str] | None = None, client_keys: List[str] | None = None + ) -> List[Rule]: + """ + Get multiple rules by rule IDs or client keys. + + Args: + rule_ids: List of rule IDs to get. + client_keys: List of client keys to get. + + Returns: + List of Rules. + """ + ... + + def batch_restore( + self, *, rule_ids: List[str] | None = None, client_keys: List[str] | None = None + ) -> None: + """ + Batch restore rules. + + Args: + rule_ids: List of rule IDs to restore. + client_keys: List of client keys to undelete. + """ + ... + + def create( + self, + name: str, + description: str, + expression: str, + channel_references: List[ChannelReference], + action: RuleAction, + organization_id: str | None = None, + client_key: str | None = None, + asset_ids: List[str] | None = None, + contextual_channels: List[str] | None = None, + is_external: bool = False, + ) -> Rule: + """ + Create a new rule. + """ + ... + + def find(self, **kwargs) -> Rule | None: + """ + Find a single rule matching the given query. Takes the same arguments as `list`. If more than one rule is found, + raises an error. + + Args: + **kwargs: Keyword arguments to pass to `list`. + + Returns: + The Rule found or None. 
+ """ + ... + + def get(self, *, rule_id: str | None = None, client_key: str | None = None) -> Rule: + """ + Get a Rule. + + Args: + rule_id: The ID of the rule. + client_key: The client key of the rule. + + Returns: + The Rule. + """ + ... + + def list( + self, + *, + name: str | None = None, + name_contains: str | None = None, + name_regex: str | re.Pattern | None = None, + order_by: str | None = None, + limit: int | None = None, + include_deleted: bool = False, + ) -> list[Rule]: + """ + List rules with optional filtering. + + Args: + name: Exact name of the rule. + name_contains: Partial name of the rule. + name_regex: Regular expression string to filter rules by name. + order_by: How to order the retrieved rules. + limit: How many rules to retrieve. If None, retrieves all matches. + + Returns: + A list of Rules that matches the filter. + """ + ... + + def restore( + self, *, rule: str | Rule, rule_id: str | None = None, client_key: str | None = None + ) -> Rule: + """ + Restore a rule. + + Args: + rule: The Rule or rule ID to restore. + rule_id: The rule ID to restore (alternative to rule parameter). + client_key: The client key to restore (alternative to rule parameter). + + Returns: + The restored Rule. + """ + ... + + def update( + self, rule: str | Rule, update: RuleUpdate | dict, version_notes: str | None = None + ) -> Rule: + """ + Update a Rule. + + Args: + rule: The Rule or rule ID to update. + update: Updates to apply to the Rule. + version_notes: Notes to include in the rule version. + Returns: + The updated Rule. + """ + ... + class RunsAPI: """ Sync counterpart to `RunsAPIAsync`. @@ -418,6 +731,15 @@ class RunsAPI: """ """ ... + def archive(self, *, run: str | Run) -> None: + """ + Archive a run. + + Args: + run: The Run or run ID to archive. + """ + ... + def create( self, name: str, @@ -459,15 +781,6 @@ class RunsAPI: """ ... - def delete(self, *, run: str | Run) -> None: - """ - Delete a run. - - Args: - run: The Run or run ID to delete. 
- """ - ... - def find(self, **kwargs) -> Run | None: """ Find a single run matching the given query. Takes the same arguments as `list`. If more than one run is found, diff --git a/python/lib/sift_client/types/__init__.py b/python/lib/sift_client/types/__init__.py index 23a640c31..6698dd0cb 100644 --- a/python/lib/sift_client/types/__init__.py +++ b/python/lib/sift_client/types/__init__.py @@ -4,13 +4,38 @@ CalculatedChannelUpdate, ) from sift_client.types.channel import ( + Channel, + ChannelBitFieldElement, + ChannelDataType, ChannelReference, ) +from sift_client.types.ingestion import IngestionConfig +from sift_client.types.rule import ( + Rule, + RuleAction, + RuleActionType, + RuleAnnotationType, + RuleUpdate, + RuleVersion, +) +from sift_client.types.run import Run, RunUpdate __all__ = [ "Asset", "AssetUpdate", "CalculatedChannel", "CalculatedChannelUpdate", + "Rule", + "RuleUpdate", + "RuleAction", + "RuleVersion", + "RuleActionType", + "RuleAnnotationType", + "Channel", + "ChannelBitFieldElement", + "ChannelDataType", "ChannelReference", + "Run", + "RunUpdate", + "IngestionConfig", ] diff --git a/python/lib/sift_client/types/_base.py b/python/lib/sift_client/types/_base.py index 5420acca0..ab19ff278 100644 --- a/python/lib/sift_client/types/_base.py +++ b/python/lib/sift_client/types/_base.py @@ -17,6 +17,7 @@ class BaseType(BaseModel, Generic[ProtoT, SelfT], ABC): model_config = ConfigDict(frozen=True) + id_: str | None = None _client: SiftClient | None = None @property diff --git a/python/lib/sift_client/types/asset.py b/python/lib/sift_client/types/asset.py index e063430fc..bfcd68126 100644 --- a/python/lib/sift_client/types/asset.py +++ b/python/lib/sift_client/types/asset.py @@ -1,11 +1,13 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Type +from typing import TYPE_CHECKING, List, Type from sift.assets.v1.assets_pb2 import Asset as AssetProto from sift_client.types._base import BaseType, 
MappingHelper, ModelUpdate +from sift_client.types.channel import Channel +from sift_client.types.run import Run from sift_client.util.metadata import metadata_dict_to_proto, metadata_proto_to_dict if TYPE_CHECKING: @@ -17,7 +19,6 @@ class Asset(BaseType[AssetProto, "Asset"]): Model of the Sift Asset. """ - id: str name: str organization_id: str created_date: datetime @@ -42,8 +43,15 @@ def created_by(self): def modified_by(self): raise NotImplementedError - def runs(self, limit: int | None = None): - return self.client.runs.list(asset_id=self.id, limit=limit) + @property + def runs(self) -> List[Run]: + return self.client.runs.list(asset_id=self.id_) + + def channels(self, run_id: str | None = None, limit: int | None = None) -> List[Channel]: + """ + Return all channels for this asset. + """ + return self.client.channels.list(asset_id=self.id_, run_id=run_id, limit=limit) @property def rules(self): @@ -78,7 +86,7 @@ def update(self, update: AssetUpdate | dict) -> Asset: @classmethod def _from_proto(cls, proto: AssetProto, sift_client: SiftClient | None = None) -> Asset: return cls( - id=proto.asset_id, + id_=proto.asset_id, name=proto.name, organization_id=proto.organization_id, created_date=proto.created_date.ToDatetime(), diff --git a/python/lib/sift_client/types/calculated_channel.py b/python/lib/sift_client/types/calculated_channel.py index 601fcd6a9..3f64046cd 100644 --- a/python/lib/sift_client/types/calculated_channel.py +++ b/python/lib/sift_client/types/calculated_channel.py @@ -22,7 +22,6 @@ class CalculatedChannel(BaseType[CalculatedChannelProto, "CalculatedChannel"]): Model of the Sift Calculated Channel. 
# Enum for channel data types (values are the protobuf enum integers).
class ChannelDataType(Enum):
    """
    Data type of a channel.

    Member values come from the generated `channel_data_type_pb2` constants.
    `str(member)` gives the short lower-case form used for lookups in
    `from_str` ("double", "bit_field", "int32", "uint64", ...).
    """

    DOUBLE = channel_pb.CHANNEL_DATA_TYPE_DOUBLE
    STRING = channel_pb.CHANNEL_DATA_TYPE_STRING
    ENUM = channel_pb.CHANNEL_DATA_TYPE_ENUM
    BIT_FIELD = channel_pb.CHANNEL_DATA_TYPE_BIT_FIELD
    BOOL = channel_pb.CHANNEL_DATA_TYPE_BOOL
    FLOAT = channel_pb.CHANNEL_DATA_TYPE_FLOAT
    INT_32 = channel_pb.CHANNEL_DATA_TYPE_INT_32
    INT_64 = channel_pb.CHANNEL_DATA_TYPE_INT_64
    UINT_32 = channel_pb.CHANNEL_DATA_TYPE_UINT_32
    UINT_64 = channel_pb.CHANNEL_DATA_TYPE_UINT_64

    def __str__(self) -> str:
        # "INT_32" -> "int32", "UINT_64" -> "uint64"; other names are only lower-cased.
        return self.name.lower().replace("int_", "int")

    @staticmethod
    def from_api_format(val: str) -> Optional["ChannelDataType"]:
        """Return the member whose API name ("CHANNEL_DATA_TYPE_<NAME>") equals `val`, else None."""
        for item in ChannelDataType:
            if "CHANNEL_DATA_TYPE_" + item.name == val:
                return item
        return None

    @staticmethod
    def from_str(raw: str) -> Optional["ChannelDataType"]:
        """
        Parse a data type from one of three textual forms.

        Args:
            raw: Either an API name ("CHANNEL_DATA_TYPE_DOUBLE"), a string
                starting with "sift.data" whose last dotted segment is a
                "<Type>Values" name, or the short form produced by `str()`
                (case-insensitive).

        Returns:
            The matching member. Only the API-name form returns None when
            nothing matches.

        Raises:
            Exception: If a "sift.data" or short-form string matches no member.
        """
        if raw.startswith("CHANNEL_DATA_TYPE_"):
            return ChannelDataType.from_api_format(raw)

        if raw.startswith("sift.data"):
            # e.g. "...BitFieldValues" -> "bitfield" -> "bit_field"
            key = raw.split(".")[-1].lower().replace("values", "")
            key = "bit_field" if key == "bitfield" else key
            not_found = "Unreachable. ChannelTypeUrls and ChannelDataType enum names are out of sync."
        else:
            key = raw.lower()
            not_found = f"Unknown channel data type: {raw}"

        for item in ChannelDataType:
            if str(item) == key:
                return item
        raise Exception(not_found)

    @staticmethod
    def proto_data_class(data_type: ChannelDataType) -> Any:
        """
        Return the `sift.data.v2` "*Values" message class for `data_type`.

        Raises:
            ValueError: If `data_type` is not a ChannelDataType member.
        """
        value_classes = {
            ChannelDataType.DOUBLE: DoubleValues,
            ChannelDataType.FLOAT: FloatValues,
            ChannelDataType.STRING: StringValues,
            ChannelDataType.ENUM: EnumValues,
            ChannelDataType.BIT_FIELD: BitFieldValues,
            ChannelDataType.BOOL: BoolValues,
            ChannelDataType.INT_32: Int32Values,
            ChannelDataType.INT_64: Int64Values,
            ChannelDataType.UINT_32: Uint32Values,
            ChannelDataType.UINT_64: Uint64Values,
        }
        try:
            return value_classes[data_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which also are not members.
            raise ValueError(f"Unknown data type: {data_type}") from None

    # TODO: Can we get rid of this? Is hashing the same between clients that likely to ever actually discover a conflict?
    def hash_str(self, api_format: bool = False) -> str:
        """Return the API name when `api_format` is true, otherwise the short `str()` form."""
        return f"CHANNEL_DATA_TYPE_{self.name}" if api_format else str(self)
unit: str | None = None + bit_field_elements: List[ChannelBitFieldElement] | None = None + enum_types: Dict[str, int] = {} + asset_id: str | None = None + created_date: datetime | None = None + modified_date: datetime | None = None + created_by_user_id: str | None = None + modified_by_user_id: str | None = None + + @staticmethod + def _enum_types_to_proto_list(enum_types: Dict[str, int]) -> List[ChannelEnumTypePb]: + """Convert a dictionary of enum types to a list of ChannelEnumTypePb objects.""" + return [ChannelEnumTypePb(name=name, key=key) for name, key in enum_types.items()] + + @staticmethod + def _enum_types_from_proto_list(enum_types: List[ChannelEnumTypePb]) -> Dict[str, int]: + """Convert a list of ChannelEnumTypePb objects to a dictionary of enum types.""" + return {enum.name: enum.key for enum in enum_types} + + @classmethod + def _from_proto( + cls, proto: ChannelProto | ChannelConfig, sift_client: SiftClient | None = None + ) -> Channel: + if isinstance(proto, ChannelProto): + return cls( + id_=proto.channel_id, + name=proto.name, + data_type=ChannelDataType(proto.data_type), + description=proto.description, + unit=proto.unit_id, + bit_field_elements=[ + ChannelBitFieldElement._from_proto(el) for el in proto.bit_field_elements + ], + enum_types=cls._enum_types_from_proto_list(proto.enum_types), # type: ignore + asset_id=proto.asset_id, + created_date=proto.created_date.ToDatetime(), + modified_date=proto.modified_date.ToDatetime(), + created_by_user_id=proto.created_by_user_id, + modified_by_user_id=proto.modified_by_user_id, + _client=sift_client, + ) + elif isinstance(proto, ChannelConfig): + return cls( + id_=proto.name, + name=proto.name, + data_type=ChannelDataType(proto.data_type), + _client=sift_client, + ) + + def _to_config_proto(self) -> ChannelConfig: + return ChannelConfig( + name=self.name, + data_type=self.data_type.value, + description=self.description, # type: ignore + unit=self.unit, # type: ignore + bit_field_elements=[el._to_proto() 
for el in self.bit_field_elements] + if self.bit_field_elements + else None, + enum_types=self._enum_types_to_proto_list(self.enum_types), + ) + + def data( + self, + *, + run_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + limit: int | None = None, + as_arrow: bool = False, + ): + """ + Retrieve channel data for this channel during the specified run. + + Args: + run_id: The run ID to get data for. + start_time: The start time to get data for. + end_time: The end time to get data for. + limit: The maximum number of data points to return. + + Returns: + A ChannelTimeSeries object. + """ + if as_arrow: + data = self.client.channels.get_data_as_arrow( + channels=[self], + run_id=run_id, + start_time=start_time, + end_time=end_time, + limit=limit, # type: ignore + ) + else: + data = self.client.channels.get_data( + channels=[self], + run_id=run_id, + start_time=start_time, + end_time=end_time, + limit=limit, # type: ignore + ) + return data + + @property + def asset(self) -> Asset: + return self.client.assets.get(asset_id=self.asset_id) + + @property + def runs(self) -> List[Run]: + return self.asset.runs class ChannelReference(BaseModel): diff --git a/python/lib/sift_client/types/ingestion.py b/python/lib/sift_client/types/ingestion.py new file mode 100644 index 000000000..7c744651f --- /dev/null +++ b/python/lib/sift_client/types/ingestion.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import math +from datetime import datetime +from typing import TYPE_CHECKING, Any, List + +from google.protobuf.empty_pb2 import Empty +from pydantic import ConfigDict +from sift.ingest.v1.ingest_pb2 import IngestWithConfigDataChannelValue +from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( + FlowConfig, +) +from sift.ingestion_configs.v2.ingestion_configs_pb2 import ( + IngestionConfig as IngestionConfigProto, +) +from sift_stream_bindings import ( + ChannelBitFieldElementPy, + ChannelConfigPy, + 
class Flow(BaseType[FlowConfig, "Flow"]):
    """
    A named group of channels that is ingested together.

    Unlike the other model types this one is mutable (`frozen=False`), so
    channels can be appended before the flow is tied to an ingestion config.
    """

    model_config = ConfigDict(frozen=False)
    name: str
    channels: List[Channel]
    ingestion_config_id: str | None = None
    run_id: str | None = None

    @classmethod
    def _from_proto(cls, proto: FlowConfig, sift_client: SiftClient | None = None) -> Flow:
        """Build a Flow from a FlowConfig message, converting each channel config."""
        parsed_channels = [Channel._from_proto(config) for config in proto.channels]
        return cls(name=proto.name, channels=parsed_channels, _client=sift_client)

    def _to_proto(self) -> FlowConfig:
        """Convert this flow to a FlowConfig message."""
        channel_configs = [member._to_config_proto() for member in self.channels]
        return FlowConfig(name=self.name, channels=channel_configs)

    def _to_rust_config(self) -> FlowConfigPy:
        """Convert this flow to the `sift_stream_bindings` flow config."""
        rust_channels = [_channel_to_rust_config(member) for member in self.channels]
        return FlowConfigPy(name=self.name, channels=rust_channels)

    def add_channel(self, channel: Channel):
        """
        Append a channel to this flow.

        Raises:
            ValueError: If `ingestion_config_id` is already set.
        """
        if not self.ingestion_config_id:
            self.channels.append(channel)
            return
        raise ValueError("Cannot add a channel to a flow after creation")

    def ingest(self, *, timestamp: datetime, channel_values: dict[str, Any]):
        """
        Send one timestamped set of channel values for this flow.

        Args:
            timestamp: Timestamp passed to the ingestion API.
            channel_values: Values passed to the ingestion API; presumably keyed
                by channel name (confirm against the ingestion API).

        Raises:
            ValueError: If `ingestion_config_id` is None.
        """
        if self.ingestion_config_id is None:
            raise ValueError("Ingestion config ID is not set.")
        ingestion_api = self.client.ingestion
        ingestion_api.ingest(
            flow=self,
            timestamp=timestamp,
            channel_values=channel_values,
        )
def _rust_channel_value_from_bitfield(
    channel: Channel, value: Any
) -> IngestWithConfigDataChannelValuePy:
    """Helper function to convert a bitfield value to a ChannelValuePy object.

    Args:
        channel: The bit field channel; its `bit_field_elements` describe the layout.
        value: The value to convert to a ChannelValuePy object.
            - A single int or bytes will be treated as representing bytes directly
            - Dicts or list of ints will be treated as representing individual bitfield elements.
              Dicts are keyed by element name; lists follow the element order.

    Returns:
        A ChannelValuePy object.

    Raises:
        ValueError: If the number of values does not match the number of
            elements, or a value does not fit in its element's `bit_count`.
    """
    assert channel.bit_field_elements is not None
    elements = channel.bit_field_elements

    # We expect individual ints or bytes to represent full bitfield values.
    if isinstance(value, (bytes, int)):
        cast_value = [value] if isinstance(value, int) else value
        return IngestWithConfigDataChannelValuePy.bitfield(cast_value)

    # We expect a dict or list of ints to represent individual bitfield elements.
    list_value = value
    if isinstance(value, dict):
        list_value = [value[field.name] for field in elements]

    if len(list_value) != len(elements):
        raise ValueError(
            f"Expected number of values passed as list to match number of bit field elements for {channel.name}, but got {len(list_value)}"
        )

    packed = 0
    total_bits = 0
    for field, field_value in zip(elements, list_value):
        if not 0 <= field_value < (1 << field.bit_count):
            raise ValueError(
                f"Value {field_value} for bit field element {field.name} of {channel.name} does not fit in {field.bit_count} bit(s)"
            )
        # Place each element at its own bit offset. Shifting by the width
        # (bit_count) made elements overlap each other.
        packed |= field_value << field.index
        total_bits = max(total_bits, field.index + field.bit_count)

    # Size the payload from the layout, not the value, so an all-zero value
    # still yields bytes of the full width.
    byte_array = packed.to_bytes(max(1, math.ceil(total_bits / 8)), "little")
    return IngestWithConfigDataChannelValuePy.bitfield(byte_array)
ChannelDataType.UINT_32: + return IngestWithConfigDataChannelValuePy.uint32(value) + elif channel.data_type == ChannelDataType.UINT_64: + return IngestWithConfigDataChannelValuePy.uint64(value) + else: + raise ValueError(f"Invalid data type: {channel.data_type}") + + +def _to_rust_type(data_type: ChannelDataType) -> ChannelDataTypePy: + if data_type == ChannelDataType.DOUBLE: + return ChannelDataTypePy.Double + elif data_type == ChannelDataType.FLOAT: + return ChannelDataTypePy.Float + elif data_type == ChannelDataType.STRING: + return ChannelDataTypePy.String + elif data_type == ChannelDataType.ENUM: + return ChannelDataTypePy.Enum + elif data_type == ChannelDataType.BIT_FIELD: + return ChannelDataTypePy.BitField + elif data_type == ChannelDataType.BOOL: + return ChannelDataTypePy.Bool + elif data_type == ChannelDataType.INT_32: + return ChannelDataTypePy.Int32 + elif data_type == ChannelDataType.INT_64: + return ChannelDataTypePy.Int64 + elif data_type == ChannelDataType.UINT_32: + return ChannelDataTypePy.Uint32 + elif data_type == ChannelDataType.UINT_64: + return ChannelDataTypePy.Uint64 + raise ValueError(f"Unknown data type: {data_type}") + + +def _to_ingestion_value(type: ChannelDataType, value: Any) -> IngestWithConfigDataChannelValue: + if value is None: + return IngestWithConfigDataChannelValue(empty=Empty()) + ingestion_type_string = type.name.lower().replace("int_", "int") + return IngestWithConfigDataChannelValue(**{ingestion_type_string: value}) diff --git a/python/lib/sift_client/types/rule.py b/python/lib/sift_client/types/rule.py new file mode 100644 index 000000000..b11cddf79 --- /dev/null +++ b/python/lib/sift_client/types/rule.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, List, Optional, Type + +from sift.rules.v1.rules_pb2 import ( + ActionKind, + AnnotationActionConfiguration, + CalculatedChannelConfig, + RuleActionConfiguration, + 
class Rule(BaseType[RuleProto, "Rule"]):
    """
    Model of the Sift Rule.
    """

    name: str
    description: str
    is_enabled: bool = True
    expression: str | None = None
    channel_references: List[ChannelReference] | None = None
    action: RuleAction | None = None
    asset_ids: List[str] | None = None
    asset_tag_ids: List[str] | None = None
    contextual_channels: List[str] | None = None
    client_key: str | None = None

    # Fields from proto
    created_date: datetime | None = None
    modified_date: datetime | None = None
    created_by_user_id: str | None = None
    modified_by_user_id: str | None = None
    organization_id: str | None = None
    rule_version: RuleVersion | None = None
    archived_date: datetime | None = None
    is_external: bool | None = None

    @property
    def is_archived(self) -> bool:
        """Whether the rule is archived."""
        # An epoch (1970-01-01) archived date is treated as "not archived".
        return self.archived_date is not None and self.archived_date > datetime(1970, 1, 1)

    @property
    def assets(self) -> List[Asset]:
        """Get the assets that this rule applies to."""
        return self.client.assets.list_(asset_ids=self.asset_ids, tag_ids=self.asset_tag_ids)

    @property
    def organization(self):
        """Get the organization that this rule belongs to."""
        raise NotImplementedError("Organization is not supported yet.")

    @property
    def created_by(self):
        """Get the user that created this rule."""
        raise NotImplementedError("Created by is not supported yet.")

    @property
    def modified_by(self):
        """Get the user that modified this rule."""
        raise NotImplementedError("Modified by is not supported yet.")

    @property
    def tags(self):
        """Get the tags that this rule applies to."""
        raise NotImplementedError("Tags is not supported yet.")

    def update(self, update: RuleUpdate | dict, version_notes: str | None = None) -> Rule:
        """
        Update the Rule.

        Args:
            update: Either a RuleUpdate instance or a dictionary of key-value pairs to update.
            version_notes: Optional notes forwarded to the rules API with the update.

        Returns:
            This instance, refreshed in place from the updated rule.
        """
        updated_rule = self.client.rules.update(
            rule=self, update=update, version_notes=version_notes
        )
        self._update(updated_rule)
        return self

    def archive(self) -> None:
        """Archive the rule."""
        self.client.rules.archive(rule=self)

    @classmethod
    def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> Rule:
        """Build a ``Rule`` from its protobuf message.

        Only the first condition, and that condition's first action, are read.
        A rule without conditions yields ``None`` for ``expression``,
        ``channel_references`` and ``action``; a condition without actions
        yields ``None`` for ``action``.
        """
        expression = None
        channel_references = None
        action = None
        if proto.conditions:
            # Guard every access to conditions[0], not just the expression.
            condition = proto.conditions[0]
            calculated_channel = condition.expression.calculated_channel
            expression = calculated_channel.expression
            channel_references = [
                ChannelReference(channel_reference=ref, channel_identifier=c.name)
                for ref, c in calculated_channel.channel_references.items()
            ]
            if condition.actions:
                action = RuleAction._from_proto(condition.actions[0])
        return cls(
            id_=proto.rule_id,
            name=proto.name,
            description=proto.description,
            expression=expression,
            channel_references=channel_references,
            action=action,
            is_enabled=proto.is_enabled,
            created_date=proto.created_date.ToDatetime(),
            modified_date=proto.modified_date.ToDatetime(),
            created_by_user_id=proto.created_by_user_id,
            modified_by_user_id=proto.modified_by_user_id,
            organization_id=proto.organization_id,
            rule_version=(
                RuleVersion._from_proto(proto.rule_version) if proto.rule_version else None
            ),
            client_key=proto.client_key if proto.client_key else None,
            asset_ids=proto.asset_configuration.asset_ids,  # type: ignore
            asset_tag_ids=proto.asset_configuration.tag_ids,  # type: ignore
            contextual_channels=[c.name for c in proto.contextual_channels.channels],
            # NOTE(review): truthiness of a message field may not reflect presence;
            # ``is_archived`` compensates by comparing against the epoch. Confirm.
            archived_date=proto.deleted_date.ToDatetime() if proto.deleted_date else None,
            is_external=proto.is_external,
            _client=sift_client,
        )


class RuleUpdate(ModelUpdate[RuleProto]):
    """
    Model of the Rule fields that can be updated.

    Note:
        - asset_ids applies this rule to those assets.
        - asset_tag_ids applies this rule to assets with those tags.
    """

    name: str | None = None
    description: str | None = None
    expression: str | None = None
    channel_references: List[ChannelReference] | None = None
    action: RuleAction | None = None
    asset_ids: List[str] | None = None
    asset_tag_ids: List[str] | None = None
    contextual_channels: List[str] | None = None

    def _get_proto_class(self) -> Type[RuleProto]:
        """Return the protobuf class this update maps onto."""
        return RuleProto

    def _add_resource_id_to_proto(self, proto_msg: RuleProto):
        """Copy the resource ID onto ``proto_msg.rule_id``.

        Raises:
            ValueError: If the resource ID has not been set.
        """
        if self._resource_id is None:
            raise ValueError("Resource ID must be set before adding to proto")
        proto_msg.rule_id = self._resource_id


class RuleActionType(Enum):
    """Enum for rule action kinds."""

    UNSPECIFIED = ActionKind.ACTION_KIND_UNSPECIFIED  # 0
    ANNOTATION = ActionKind.ANNOTATION  # 1
    WEBHOOK = ActionKind.WEBHOOK  # 2

    @classmethod
    def from_str(cls, val: str) -> Optional["RuleActionType"]:
        """Resolve a member from a name or a numeric value.

        Accepts a member name with or without the ``ACTION_KIND_`` prefix, or
        anything ``int()`` can convert to a member value.

        Raises:
            ValueError: If ``val`` matches no member.
        """
        if isinstance(val, str):
            prefix = "ACTION_KIND_"
            member_name = val[len(prefix) :] if val.startswith(prefix) else val
            if member_name in cls.__members__:
                return cls[member_name]
        try:
            return cls(int(val))
        except ValueError as err:
            raise ValueError(f"Unknown rule action type: {val!r}") from err


class RuleAnnotationType(Enum):
    """Enum for rule annotation types."""

    UNSPECIFIED = 0
    DATA_REVIEW = 1
    PHASE = 2

    @classmethod
    def from_str(cls, val: str) -> Optional["RuleAnnotationType"]:
        """Resolve a member from a name or a numeric value.

        Accepts a member name with or without the ``ANNOTATION_TYPE_`` prefix,
        or anything ``int()`` can convert to a member value.

        Raises:
            ValueError: If ``val`` matches no member.
        """
        if isinstance(val, str):
            prefix = "ANNOTATION_TYPE_"
            member_name = val[len(prefix) :] if val.startswith(prefix) else val
            if member_name in cls.__members__:
                return cls[member_name]
        try:
            return cls(int(val))
        except ValueError as err:
            raise ValueError(f"Unknown rule annotation type: {val!r}") from err


class RuleAction(BaseType[RuleActionProto, "RuleAction"]):
    """
    Model of a Rule Action.
    """

    action_type: RuleActionType
    condition_id: str | None = None
    created_date: datetime | None = None
    modified_date: datetime | None = None
    created_by_user_id: str | None = None
    modified_by_user_id: str | None = None
    version_id: str | None = None
    annotation_type: RuleAnnotationType | None = None
    tags: List[str] | None = None
    default_assignee_user_id: str | None = None

    @classmethod
    def annotation(
        cls,
        annotation_type: RuleAnnotationType,
        tags: List[str],
        default_assignee_user_id: str | None = None,
    ) -> RuleAction:
        """Create an annotation action.

        Args:
            annotation_type: Type of annotation to create.
            tags: List of tag IDs to add to the annotation.
            default_assignee_user_id: User ID to assign the annotation to.
        """
        return cls(
            action_type=RuleActionType.ANNOTATION,
            annotation_type=annotation_type,
            tags=tags,
            default_assignee_user_id=default_assignee_user_id,
        )

    @classmethod
    def _from_proto(
        cls, proto: RuleActionProto, sift_client: SiftClient | None = None
    ) -> RuleAction:
        """Build a ``RuleAction`` from its protobuf message.

        Empty tag lists and an empty assignee ID are normalised to ``None``.
        ``annotation_type`` is only populated for annotation actions.
        """
        action_type = RuleActionType(proto.action_type)
        annotation_config = proto.configuration.annotation
        return cls(
            condition_id=proto.rule_condition_id,
            created_date=proto.created_date.ToDatetime(),
            modified_date=proto.modified_date.ToDatetime(),
            created_by_user_id=proto.created_by_user_id,
            modified_by_user_id=proto.modified_by_user_id,
            version_id=proto.rule_action_version_id,
            tags=list(annotation_config.tag_ids) if annotation_config.tag_ids else None,
            default_assignee_user_id=(
                annotation_config.assigned_to_user_id
                if annotation_config.assigned_to_user_id
                else None
            ),
            action_type=action_type,
            annotation_type=RuleAnnotationType.from_str(
                annotation_config.annotation_type  # type: ignore
            )
            if action_type == RuleActionType.ANNOTATION
            else None,
            _client=sift_client,
        )

    def _to_update_request(self) -> UpdateActionRequest:
        """Convert this action into an ``UpdateActionRequest``.

        The annotation configuration is only attached for annotation actions.
        """
        return UpdateActionRequest(
            action_type=self.action_type.value,
            configuration=RuleActionConfiguration(
                annotation=(
                    AnnotationActionConfiguration(
                        assigned_to_user_id=self.default_assignee_user_id,
                        tag_ids=self.tags,
                        annotation_type=self.annotation_type.value,  # type: ignore
                    )
                    if self.action_type == RuleActionType.ANNOTATION
                    else None
                ),
            ),
        )


class RuleVersion(BaseType[RuleVersionProto, "RuleVersion"]):
    """
    Model of a Rule Version.
    """

    rule_id: str
    rule_version_id: str
    version: str
    created_date: datetime
    created_by_user_id: str
    version_notes: str
    generated_change_message: str
    deleted_date: datetime | None = None

    @classmethod
    def _from_proto(
        cls, proto: RuleVersionProto, sift_client: SiftClient | None = None
    ) -> RuleVersion:
        """Build a ``RuleVersion`` from its protobuf message."""
        return cls(
            rule_id=proto.rule_id,
            rule_version_id=proto.rule_version_id,
            version=proto.version,
            created_date=proto.created_date.ToDatetime(),
            created_by_user_id=proto.created_by_user_id,
            version_notes=proto.version_notes,
            generated_change_message=proto.generated_change_message,
            deleted_date=proto.deleted_date.ToDatetime() if proto.deleted_date else None,
            _client=sift_client,
        )
modified_date=proto.modified_date.ToDatetime(), created_by_user_id=proto.created_by_user_id, @@ -92,12 +92,12 @@ def _from_proto(cls, proto: RunProto, sift_client: SiftClient | None = None) -> _client=sift_client, ) - def to_proto(self) -> RunProto: + def _to_proto(self) -> RunProto: """ Convert to protobuf message. """ proto = RunProto( - run_id=self.id, + run_id=self.id_ or "", created_date=self.created_date, # type: ignore modified_date=self.modified_date, # type: ignore created_by_user_id=self.created_by_user_id, @@ -128,7 +128,8 @@ def to_proto(self) -> RunProto: return proto - def assets(self): + @property + def assets(self) -> List[Asset]: """ Return all assets associated with this run. """ @@ -137,11 +138,3 @@ def assets(self): if not self.asset_ids: return [] return self.client.assets.list_(asset_ids=self.asset_ids) - - def stop(self): - """ - Stop the run. - """ - if not hasattr(self, "client") or self.client is None: - raise RuntimeError("Run is not bound to a client instance.") - self.client.runs.stop_run(self.id) diff --git a/python/lib/sift_client/util/timestamp.py b/python/lib/sift_client/util/timestamp.py index 8a67a3dcf..96844560c 100644 --- a/python/lib/sift_client/util/timestamp.py +++ b/python/lib/sift_client/util/timestamp.py @@ -1,9 +1,17 @@ from datetime import datetime from google.protobuf.timestamp_pb2 import Timestamp +from sift_stream_bindings import TimeValuePy def to_pb_timestamp(timestamp: datetime) -> Timestamp: timestamp_pb = Timestamp() timestamp_pb.FromDatetime(timestamp) return timestamp_pb + + +def to_rust_py_timestamp(time: datetime) -> TimeValuePy: + ts = time.timestamp() + secs = int(ts) + nsecs = int((ts - secs) * 1_000_000_000) + return TimeValuePy.from_timestamp(secs, nsecs) diff --git a/python/lib/sift_client/util/util.py b/python/lib/sift_client/util/util.py index d9f74ab5a..1c83ccc58 100644 --- a/python/lib/sift_client/util/util.py +++ b/python/lib/sift_client/util/util.py @@ -5,7 +5,10 @@ from sift_client.resources 
import ( AssetsAPIAsync, CalculatedChannelsAPIAsync, + ChannelsAPIAsync, + IngestionAPIAsync, PingAPIAsync, + RulesAPIAsync, RunsAPIAsync, ) @@ -22,5 +25,14 @@ class AsyncAPIs(NamedTuple): calculated_channels: CalculatedChannelsAPIAsync """Instance of the Calculated Channels API for making asynchronous requests.""" + channels: ChannelsAPIAsync + """Instance of the Channels API for making asynchronous requests.""" + + ingestion: IngestionAPIAsync + """Instance of the Ingestion API for making asynchronous requests.""" + runs: RunsAPIAsync """Instance of the Runs API for making asynchronous requests.""" + + rules: RulesAPIAsync + """Instance of the Rules API for making asynchronous requests.""" diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index b6aff1438..07d13f667 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -118,7 +118,7 @@ def _use_insecure_sift_async_channel( FOR DEVELOPMENT PURPOSES ONLY """ return grpc_aio.insecure_channel( - target=config["uri"], + target=_clean_uri(config["uri"], False), options=_compute_channel_options(config), interceptors=_compute_sift_async_interceptors(config, metadata), ) diff --git a/python/pyproject.toml b/python/pyproject.toml index fe5735604..b3556eb3a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -47,6 +47,7 @@ Changelog = "https://github.com/sift-stack/sift/tree/main/python/CHANGELOG.md" development = [ "grpcio-testing~=1.13", "mypy==1.10.0", + "pyarrow>=17.0.0", # sift_client, older version to support py3.8 "pyright==1.1.386", "pytest==8.2.2", "pytest-asyncio==0.23.7", @@ -59,9 +60,8 @@ docs = ["mkdocs", "mkdocs-material", "mkdocstrings[python]", "mkdocs-include-mar openssl = ["pyOpenSSL<24.0.0", "types-pyOpenSSL<24.0.0", "cffi~=1.14"] tdms = ["npTDMS~=1.9"] rosbags = ["rosbags~=0.0"] +sift-stream = ["sift-stream-bindings>=0.1.2"] hdf5 = ["h5py~=3.11", "polars~=1.8"] - - [build-system] requires = ["setuptools"] 
build-backend = "setuptools.build_meta" @@ -125,6 +125,11 @@ module = "grpc.aio" ignore_missing_imports = true ignore_errors = true +[[tool.mypy.overrides]] +module = "pyarrow" +ignore_missing_imports = true +ignore_errors = true + [[tool.mypy.overrides]] module = "requests_toolbelt" ignore_missing_imports = true