-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathdata.py
More file actions
375 lines (328 loc) · 14.6 KB
/
data.py
File metadata and controls
375 lines (328 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from math import ceil
from typing import TYPE_CHECKING, Any, cast
import pandas as pd
from pydantic import BaseModel, ConfigDict
from sift.data.v2.data_pb2 import (
BitFieldValues,
ChannelQuery,
GetDataRequest,
GetDataResponse,
Query,
)
from sift.data.v2.data_pb2_grpc import DataServiceStub
from sift_py._internal.time import to_timestamp_nanos
from sift_client._internal.low_level_wrappers.base import LowLevelClientBase
from sift_client.sift_types.channel import Channel, ChannelDataType
from sift_client.transport import WithGrpcClient
if TYPE_CHECKING:
from sift_client.transport.grpc_transport import GrpcClient
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in the visible part of this module; presumably a
# default page size used by callers elsewhere - confirm before changing or removing.
CHANNELS_DEFAULT_PAGE_SIZE = 10_000
# TODO: There is a pagination issue on the API side when requesting multiple channels in a
# single request. If the data points for all channels in one request don't fit into a single
# page, paging seems to omit all but a single channel. We can increase this batch size once
# that issue has been resolved. In the meantime, each channel gets its own request.
REQUEST_BATCH_SIZE = 1
class ChannelCacheEntry(BaseModel):
    """Cached data for a single channel together with the time window it was stored for."""

    # pandas DataFrames are not pydantic-native types, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # The cached samples.
    data: pd.DataFrame
    # Start and end of the time window recorded for this entry.
    start_time: datetime
    end_time: datetime
class ChannelCache(BaseModel):
    """Container for cached channel data plus a channel-name lookup table."""

    # Maps a channel name (or "<channel name>.<bit field element name>") to a channel ID.
    name_id_map: dict[str, str]
    # Maps a channel ID to the cached entry for that channel.
    channels: dict[str, ChannelCacheEntry]
class DataLowLevelClient(LowLevelClientBase, WithGrpcClient):
    """Low-level client for fetching channel data.

    This class provides a thin wrapper around the autogenerated bindings for the DataAPI
    and keeps an in-memory cache of the channel data it has fetched.
    """

    # NOTE(review): declared on the class, so one cache object is shared by every instance.
    # Entries are keyed by channel ID only; the run ID is not part of the key.
    channel_cache: ChannelCache = ChannelCache(name_id_map={}, channels={})

    def __init__(self, grpc_client: GrpcClient):
        """Initialize the DataLowLevelClient.

        Args:
            grpc_client: The gRPC client to use for making API calls.
        """
        super().__init__(grpc_client)

    def _update_name_id_map(self, channels: list[Channel]):
        """Record the name -> ID mapping for the given channels.

        Bit field channels additionally get one ``"<channel>.<element>"`` key per bit
        field element, all pointing at the parent channel's ID.

        Args:
            channels: Channels whose names should be resolvable to IDs later on.
        """
        name_id_map = self.channel_cache.name_id_map
        for channel in channels:
            channel_id = str(channel.id_)
            for element in channel.bit_field_elements or ():
                name_id_map[f"{channel.name}.{element.name}"] = channel_id
            name_id_map[channel.name] = channel_id

    async def _get_data_impl(
        self,
        *,
        channel_ids: list[str],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime,
        page_size: int | None = None,
        page_token: str | None = None,
        order_by: str | None = None,
        use_cache: bool = False,
        force_refresh: bool = False,
        cache_ttl: int | None = None,
    ) -> tuple[list[Any], str | None]:
        """Fetch one page of data for the given channels.

        Args:
            channel_ids: List of channel IDs to fetch data for.
            run_id: Optional run ID to filter data.
            start_time: Optional start time for the data range.
            end_time: End time for the data range.
            page_size: Number of results per page.
            page_token: Token for pagination.
            order_by: Accepted for signature compatibility with the pagination helper;
                it is not forwarded in the request.
            use_cache: Whether to enable caching for this request. Default: False.
            force_refresh: Whether to force refresh the cache. Default: False.
            cache_ttl: Optional custom TTL in seconds for cached responses.

        Returns:
            Tuple of (data list, next page token).
        """
        queries = [
            Query(channel=ChannelQuery(channel_id=channel_id, run_id=run_id))
            for channel_id in channel_ids
        ]
        request_kwargs: dict[str, Any] = {
            "queries": queries,
            "sample_ms": 0,
            "start_time": start_time,
            "end_time": end_time,
            "page_size": page_size,
            "page_token": page_token,
        }
        request = GetDataRequest(**request_kwargs)

        stub = self._grpc_client.get_stub(DataServiceStub)
        if use_cache or force_refresh:
            # Route the call through the response-cache helper.
            response = await self._call_with_cache(
                stub.GetData,
                request,
                use_cache=use_cache,
                force_refresh=force_refresh,
                ttl=cache_ttl,
            )
        else:
            response = await stub.GetData(request)

        response = cast("GetDataResponse", response)
        # mypy doesn't know RepeatedCompositeFieldContainer can be treated like a list.
        return response.data, response.next_page_token  # type: ignore

    def _filter_cached_channels(self, channel_ids: list[str]) -> tuple[list[str], list[str]]:
        """Split channel IDs into those with a cache entry and those without.

        Args:
            channel_ids: Channel IDs to classify.

        Returns:
            Tuple of (cached channel IDs, not-cached channel IDs), each in input order.
        """
        cached: list[str] = []
        not_cached: list[str] = []
        entries = self.channel_cache.channels
        for channel_id in channel_ids:
            (cached if channel_id in entries else not_cached).append(channel_id)
        return cached, not_cached

    def _check_cache(
        self,
        *,
        channel_id: str,
        start_time: datetime,
        end_time: datetime,
        run_id: str | None = None,
    ) -> tuple[pd.DataFrame | None, datetime | None, datetime | None]:
        """Look up cached data for a channel and work out what still has to be fetched.

        There are a variety of requested start/end time vs cached start/end time cases
        to consider. The diagram below shows time-aligned ranges for each case:

            Cache interval:           |-------------------------------|
            Case 1:                     |---------------------------|
            Case 2:                         |--------------------------------|
            Case 3:                                                       |----------|
            Case 4:           |--------------------------------|
            Case 5:         |------|
                    or:         |-------------------------------------------|

        Cases 3 and 5 cannot be served by extending the cached window on one side, so
        the cache is not used and the full range is requested again.

        Args:
            channel_id: ID of the channel to look up.
            start_time: Start of the requested range.
            end_time: End of the requested range.
            run_id: Currently unused; cache entries are keyed by channel ID only.

        Returns:
            A tuple of (data, start_time, end_time) where data is the cached DataFrame
            sliced to the requested range (or None when the cache is not used) and the
            start/end times describe the range that still has to be queried. Both times
            are None when the cache fully covers the request.
        """
        entry = self.channel_cache.channels.get(channel_id)
        if entry is None:
            return (None, start_time, end_time)

        cached_start = entry.start_time
        cached_end = entry.end_time

        if cached_start <= start_time:
            if start_time >= cached_end:
                # Case 3: the request begins at or after the end of the cached window.
                return (None, start_time, end_time)
            # mypy doesn't understand label slicing on a DataFrame.
            window = entry.data[start_time:end_time]  # type: ignore
            if end_time <= cached_end:
                # Case 1: fully covered, nothing left to fetch.
                return (window, None, None)
            # Case 2: only the tail after the cached window is missing.
            return (window, cached_end, end_time)

        if cached_start < end_time <= cached_end:
            # Case 4: only the head before the cached window is missing.
            window = entry.data[start_time:end_time]  # type: ignore
            return (window, start_time, cached_start)

        # Case 5: the request lies before the cached window or spans beyond both ends.
        return (None, start_time, end_time)

    def _update_cache(
        self,
        *,
        channel_data: dict[str, pd.DataFrame],
        start_time: datetime,
        end_time: datetime,
        run_id: str | None = None,
    ):
        """Update the cache with the new data and start/end times.

        When the new window overlaps or touches an existing entry, the data is merged
        (the newer value wins for duplicate timestamps) and the recorded window is
        widened. When the windows are disjoint, the entry is replaced instead: merging
        would record a window spanning a gap that was never fetched, and later requests
        inside that gap would wrongly be answered from the cache.

        Args:
            channel_data: Mapping of channel name to the DataFrame to cache.
            start_time: Start of the window the data was requested for.
            end_time: End of the window the data was requested for.
            run_id: Run ID the data was requested with, if any. With a run ID the
                recorded window starts at the first sample rather than at start_time.

        Raises:
            ValueError: If start_time or end_time is None, or if a channel name is
                missing from the name -> ID map.
        """
        # NOTE(review): if the fetch was cut short by the page limit, the requested
        # window is still recorded as fully covered.
        if start_time is None or end_time is None:
            raise ValueError("start_time and end_time are required to update the cache.")

        name_id_map = self.channel_cache.name_id_map
        entries = self.channel_cache.channels
        for channel_name, data in channel_data.items():
            channel_id = name_id_map.get(channel_name)
            if not channel_id:
                raise ValueError(
                    f"{channel_name} not found in name_id_map. Not sure got data for this channel without a call that should've updated the map."
                )

            window_start = start_time
            if run_id:
                if len(data) == 0:
                    # Because we didn't get any data, we can't know what the start time
                    # should be. And because this was queried with a run ID, we can't say
                    # there's no data before the run started. So don't update the cache.
                    continue
                window_start = data.index[0]

            entry = entries.get(channel_id)
            overlaps = (
                entry is not None
                and window_start <= entry.end_time
                and entry.start_time <= end_time
            )
            if entry is not None and overlaps:
                entry.data = pd.concat([entry.data, data]).groupby(level=0).last()
                entry.start_time = min(window_start, entry.start_time)
                entry.end_time = max(end_time, entry.end_time)
            else:
                # No entry yet, or the windows are disjoint: start a fresh entry.
                entries[channel_id] = ChannelCacheEntry(
                    data=data,
                    start_time=window_start,
                    end_time=end_time,
                )

    async def get_channel_data(
        self,
        *,
        channels: list[Channel],
        run_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int | None = None,
        ignore_cache: bool = False,
    ) -> dict[str, pd.DataFrame]:
        """Get the data for the given channels, using the cache where possible.

        Args:
            channels: Channels to fetch data for.
            run_id: Optional run ID to filter data.
            start_time: Start of the range. Defaults to the Unix epoch (UTC).
            end_time: End of the range. Defaults to the current time (UTC).
            limit: Approximate maximum number of data points per channel. It is turned
                into a page size (at most 1000) and a page count; when omitted, up to
                10 pages of 1000 are requested.
            ignore_cache: When True, cached data is not read. The cache is still
                updated with whatever is fetched.

        Returns:
            Mapping of channel name (``"<channel>.<element>"`` for bit field elements)
            to a DataFrame of its values.
        """
        ret_data: dict[str, pd.DataFrame] = {}
        # No data will be returned if end_time is not provided.
        start_time = start_time or datetime.fromtimestamp(0, tz=timezone.utc)
        end_time = end_time or datetime.now(timezone.utc)

        self._update_name_id_map(channels)
        channel_ids = [c.id_ for c in channels]
        if ignore_cache:
            cached_ids, uncached_ids = [], list(channel_ids)
        else:
            cached_ids, uncached_ids = self._filter_cached_channels(channel_ids)  # type: ignore

        page_size = limit if limit and limit < 1000 else 1000
        # NOTE(review): passed as `max_results` below; confirm how `_handle_pagination`
        # counts results against it.
        max_pages = ceil(limit / page_size) if limit else 10

        tasks = []

        def _enqueue(ids: list[Any], fetch_start: datetime, fetch_end: datetime) -> None:
            """Schedule a paginated fetch for `ids` over [fetch_start, fetch_end]."""
            tasks.append(
                asyncio.create_task(
                    self._handle_pagination(
                        self._get_data_impl,
                        kwargs={
                            "channel_ids": ids,
                            "run_id": run_id,
                            "start_time": fetch_start,
                            "end_time": fetch_end,
                        },
                        page_size=page_size,
                        max_results=max_pages,
                    )
                )
            )

        # Queue up calls for non-cached channels in batches.
        for offset in range(0, len(uncached_ids), REQUEST_BATCH_SIZE):
            _enqueue(uncached_ids[offset : offset + REQUEST_BATCH_SIZE], start_time, end_time)

        # Handle cached channels one by one instead of in batches to account for channels
        # that may have been cached from calls with different start/end times.
        for channel_id in cached_ids:
            cached_data, new_start_time, new_end_time = self._check_cache(
                channel_id=channel_id,
                start_time=start_time,
                end_time=end_time,
                run_id=run_id,
            )
            if cached_data is not None:
                # A cache entry can hold several columns (one per bit field element).
                # Hand out one single-column frame per name, matching the shape that
                # freshly fetched data has.
                for name in cached_data.columns:
                    ret_data[name] = cached_data[[name]]
                if new_start_time is None:
                    # Cache fully encompassed the desired time range so don't queue a call.
                    continue
            _enqueue([channel_id], new_start_time, new_end_time or end_time)  # type: ignore

        pages = await asyncio.gather(*tasks)

        # Flatten the fetched pages into one frame per channel name; for duplicate
        # timestamps the most recently merged value wins.
        for page in pages:
            for data in page:
                for name, df in self.try_deserialize_channel_data(data).items():
                    if name in ret_data:
                        ret_data[name] = pd.concat([ret_data[name], df]).groupby(level=0).last()
                    else:
                        ret_data[name] = df

        self._update_cache(
            channel_data=ret_data, start_time=start_time, end_time=end_time, run_id=run_id
        )
        return ret_data

    @staticmethod
    def try_deserialize_channel_data(channel_data: Any) -> dict[str, pd.DataFrame]:
        """Deserialize a channel data object into per-channel DataFrames.

        Args:
            channel_data: Object exposing `type_url` and `value`, as found in the
                `data` field of a GetData response.

        Returns:
            Mapping of name to a single-column DataFrame indexed by timestamp. Bit field
            payloads yield one ``"<channel>.<element>"`` entry per element; every other
            payload yields one entry named after the channel.

        Raises:
            ValueError: If the payload's type URL does not map to a known data type.
        """
        data_type = ChannelDataType.from_str(channel_data.type_url)
        if data_type is None:
            raise ValueError(f"Unknown data type: {channel_data.type_url}")

        proto_data_class = ChannelDataType.proto_data_class(data_type)
        proto_data_value = proto_data_class.FromString(channel_data.value)
        channel_name = proto_data_value.metadata.channel.name

        is_bit_field = proto_data_class is BitFieldValues
        components = proto_data_value.values if is_bit_field else [proto_data_value]

        ret_data = {}
        for component in components:
            name = f"{channel_name}.{component.name}" if is_bit_field else channel_name
            samples = component.values
            time_column = [to_timestamp_nanos(sample.timestamp) for sample in samples]
            value_column = [sample.value for sample in samples]
            ret_data[name] = pd.DataFrame({name: value_column}, index=time_column)
        return ret_data