-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathper_partition_cursor.py
More file actions
365 lines (311 loc) · 16.5 KB
/
per_partition_cursor.py
File metadata and controls
365 lines (311 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from collections import OrderedDict
from typing import Any, Callable, Iterable, Mapping, Optional, Union
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import (
PerPartitionKeySerializer,
)
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
# Module-level logger bound to the shared "airbyte" logger name; PerPartitionCursor uses it to
# warn whenever a partition cursor is dropped because the partition limit was reached.
logger = logging.getLogger("airbyte")
class CursorFactory:
    """
    Thin wrapper around a zero-argument callable that builds `DeclarativeCursor` objects.

    `create` invokes the wrapped callable once per call and returns whatever it produces.
    """

    def __init__(self, create_function: Callable[[], DeclarativeCursor]):
        # The callable is the factory's only state; it is invoked lazily by `create`.
        self._create_function = create_function

    def create(self) -> DeclarativeCursor:
        """Call the wrapped callable and return the cursor it builds."""
        build_cursor = self._create_function
        new_cursor = build_cursor()
        return new_cursor
class PerPartitionCursor(DeclarativeCursor):
    """
    Manages state per partition when a stream has many partitions, to prevent data loss or duplication.

    **Partition Limitation and Limit Reached Logic**

    - **DEFAULT_MAX_PARTITIONS_NUMBER**: The maximum number of partitions to keep in memory (default is 10,000).
    - **_cursor_per_partition**: An ordered dictionary that stores cursors for each partition.
    - **_over_limit**: A counter that increments each time an oldest partition is removed when the limit is exceeded.

    The class ensures that the number of partitions tracked while generating slices does not exceed
    `DEFAULT_MAX_PARTITIONS_NUMBER`, to prevent excessive memory usage.

    - When a new partition must be tracked while the limit is already reached, the oldest partitions are
      removed from `_cursor_per_partition`, and `_over_limit` is incremented accordingly. Partitions that
      already have a cursor never trigger an eviction.
    - The `limit_reached` method returns `True` when `_over_limit` exceeds `DEFAULT_MAX_PARTITIONS_NUMBER`,
      indicating that the global cursor should be used instead of per-partition cursors.

    This approach avoids unnecessary switching to a global cursor due to temporary spikes in partition
    counts, ensuring that switching is only done when a sustained high number of partitions is observed.
    """

    DEFAULT_MAX_PARTITIONS_NUMBER = 10000
    _NO_STATE: Mapping[str, Any] = {}
    _NO_CURSOR_STATE: Mapping[str, Any] = {}
    _KEY = 0
    _VALUE = 1
    # Class-level default. `set_initial_state` rebinds this attribute on the instance and never
    # mutates the shared dict, so instances do not leak state into each other.
    _state_to_migrate_from: Mapping[str, Any] = {}

    def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter):
        self._cursor_factory = cursor_factory
        self._partition_router = partition_router
        # The dict is ordered to ensure that once the maximum number of partitions is reached,
        # the oldest partitions can be efficiently removed, maintaining the most recent partitions.
        self._cursor_per_partition: OrderedDict[str, DeclarativeCursor] = OrderedDict()
        self._over_limit = 0
        self._partition_serializer = PerPartitionKeySerializer()

    def stream_slices(self) -> Iterable[StreamSlice]:
        """Yield every cursor slice of every partition emitted by the partition router."""
        slices = self._partition_router.stream_slices()
        for partition in slices:
            yield from self.generate_slices_from_partition(partition)

    def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
        """
        Yield the cursor slices of a single partition, creating its cursor on first use.

        A partition without a cursor gets one seeded with the state to migrate from (or an empty
        state). Each yielded slice combines the partition, one cursor slice, and the partition's
        extra fields.

        Args:
            partition: A slice emitted by the partition router.
        """
        partition_key = self._to_partition_key(partition.partition)
        cursor = self._cursor_per_partition.get(partition_key)
        if not cursor:
            # Only make room when a new cursor is about to be inserted. Evicting before the lookup
            # could drop the very partition being processed (e.g. the oldest partition loaded from
            # state), which would then be recreated without its saved state.
            self._ensure_partition_limit()
            cursor = self._create_cursor(self._default_partition_state())
            self._cursor_per_partition[partition_key] = cursor

        for cursor_slice in cursor.stream_slices():
            yield StreamSlice(
                partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
            )

    def _ensure_partition_limit(self) -> None:
        """
        Ensure the maximum number of partitions is not exceeded. If so, the oldest added partition will be dropped.

        After this call at most `DEFAULT_MAX_PARTITIONS_NUMBER - 1` cursors remain, leaving room for one
        insertion. `_over_limit` is incremented once per dropped partition.
        """
        while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
            self._over_limit += 1
            # popitem(last=False) removes in FIFO (insertion) order, i.e. the oldest partition.
            oldest_partition = self._cursor_per_partition.popitem(last=False)[0]
            logger.warning(
                f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}."
            )

    def limit_reached(self) -> bool:
        """Return True once more than `DEFAULT_MAX_PARTITIONS_NUMBER` partitions have been dropped."""
        return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER

    def set_initial_state(self, stream_state: StreamState) -> None:
        """
        Set the initial state for the cursors.

        This method initializes the state for each partition cursor using the provided stream state.
        If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state.

        Additionally, it sets the parent state for partition routers that are based on parent streams. If a partition router
        does not have parent streams, this step will be skipped due to the default PartitionRouter implementation.

        Args:
            stream_state (StreamState): The state of the streams to be set. The format of the stream state should be:
                {
                    "states": [
                        {
                            "partition": {
                                "partition_key": "value"
                            },
                            "cursor": {
                                "last_updated": "2023-05-27T00:00:00Z"
                            }
                        }
                    ],
                    "parent_state": {
                        "parent_stream_name": {
                            "last_updated": "2023-05-27T00:00:00Z"
                        }
                    }
                }
                A state without a "states" key is treated as a global state applied to every partition.
                An optional top-level "state" key next to "states" seeds partitions missing from "states".
        """
        if not stream_state:
            return

        if "states" not in stream_state:
            # We assume that `stream_state` is in a global format that can be applied to all partitions.
            # Example: {"global_state_format_key": "global_state_format_value"}
            self._state_to_migrate_from = stream_state
        else:
            for state in stream_state["states"]:
                self._cursor_per_partition[self._to_partition_key(state["partition"])] = (
                    self._create_cursor(state["cursor"])
                )

            # set default state for missing partitions if it is per partition with fallback to global
            if "state" in stream_state:
                self._state_to_migrate_from = stream_state["state"]

        # Set parent state for partition routers based on parent streams
        self._partition_router.set_initial_state(stream_state)

    def observe(self, stream_slice: StreamSlice, record: Record) -> None:
        """Forward `record` to the cursor of the slice's partition, stripping the partition part."""
        self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].observe(
            StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), record
        )

    def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
        """
        Close the cursor slice on the cursor of the slice's partition.

        Raises:
            ValueError: If the slice's partition has no cursor (chained from the underlying KeyError).
        """
        try:
            self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].close_slice(
                StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), *args
            )
        except KeyError as exception:
            raise ValueError(
                f"Partition {str(exception)} could not be found in current state based on the record. This is unexpected because "
                f"we should only update state for partitions that were emitted during `stream_slices`"
            ) from exception

    def get_stream_state(self) -> StreamState:
        """
        Build the per-partition stream state.

        Partitions whose cursor reports an empty state are omitted. "parent_state" is included only
        when the partition router reports one.
        """
        states = []
        for partition_key, cursor in self._cursor_per_partition.items():
            cursor_state = cursor.get_stream_state()
            if cursor_state:
                states.append(
                    {
                        "partition": self._to_dict(partition_key),
                        "cursor": cursor_state,
                    }
                )
        state: dict[str, Any] = {"states": states}

        parent_state = self._partition_router.get_stream_state()
        if parent_state:
            state["parent_state"] = parent_state
        return state

    def _get_state_for_partition(self, partition: Mapping[str, Any]) -> Optional[StreamState]:
        """Return the state of the partition's cursor, or None when the partition has no cursor."""
        cursor = self._cursor_per_partition.get(self._to_partition_key(partition))
        if cursor:
            return cursor.get_stream_state()
        return None

    @staticmethod
    def _is_new_state(stream_state: Mapping[str, Any]) -> bool:
        return not bool(stream_state)

    def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
        """Serialize a partition mapping into the string key used by `_cursor_per_partition`."""
        return self._partition_serializer.to_partition_key(partition)

    def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
        """Deserialize a partition key back into its partition mapping."""
        return self._partition_serializer.to_partition(partition_key)

    def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
        """
        Return the state of the cursor for `stream_slice`'s partition (None if it has no cursor).

        Raises:
            ValueError: If `stream_slice` is not provided (or is falsy).
        """
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to extract a state")
        return self._get_state_for_partition(stream_slice.partition)

    def _create_cursor(self, cursor_state: Any) -> DeclarativeCursor:
        """Build a cursor from the factory and seed it with `cursor_state`."""
        cursor = self._cursor_factory.create()
        cursor.set_initial_state(cursor_state)
        return cursor

    def _default_partition_state(self) -> Mapping[str, Any]:
        """Return the state used to seed cursors of partitions that have no cursor yet."""
        return self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE

    def _get_or_create_cursor(self, partition_key: str) -> DeclarativeCursor:
        """Return the cursor for `partition_key`, creating it via `_create_cursor_for_partition` if missing."""
        if partition_key not in self._cursor_per_partition:
            self._create_cursor_for_partition(partition_key)
        return self._cursor_per_partition[partition_key]

    def get_request_params(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        """
        Merge the partition router's request params with those of the partition's cursor.

        The router receives only the partition part of the slice and the cursor only the cursor part.
        The cursor's result is the right-hand operand of `|`, so with dict semantics it wins on key collisions.

        Raises:
            ValueError: If `stream_slice` is not provided (or is falsy).
        """
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to get request params")
        cursor = self._get_or_create_cursor(self._to_partition_key(stream_slice.partition))
        return self._partition_router.get_request_params(  # type: ignore # this always returns a mapping
            stream_state=stream_state,
            stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}),
            next_page_token=next_page_token,
        ) | cursor.get_request_params(
            stream_state=stream_state,
            stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
            next_page_token=next_page_token,
        )

    def get_request_headers(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        """
        Merge the partition router's request headers with those of the partition's cursor.

        Raises:
            ValueError: If `stream_slice` is not provided (or is falsy).
        """
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to get request headers")
        cursor = self._get_or_create_cursor(self._to_partition_key(stream_slice.partition))
        return self._partition_router.get_request_headers(  # type: ignore # this always returns a mapping
            stream_state=stream_state,
            stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}),
            next_page_token=next_page_token,
        ) | cursor.get_request_headers(
            stream_state=stream_state,
            stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
            next_page_token=next_page_token,
        )

    def get_request_body_data(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Union[Mapping[str, Any], str]:
        """
        Merge the partition router's request body data with that of the partition's cursor.

        NOTE(review): the merge uses `|`, which assumes both sides are mappings even though the
        return type allows `str` — confirm string bodies never reach this path.

        Raises:
            ValueError: If `stream_slice` is not provided (or is falsy).
        """
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to get request body data")
        cursor = self._get_or_create_cursor(self._to_partition_key(stream_slice.partition))
        return self._partition_router.get_request_body_data(  # type: ignore # this always returns a mapping
            stream_state=stream_state,
            stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}),
            next_page_token=next_page_token,
        ) | cursor.get_request_body_data(
            stream_state=stream_state,
            stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
            next_page_token=next_page_token,
        )

    def get_request_body_json(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
        """
        Merge the partition router's JSON request body with that of the partition's cursor.

        Raises:
            ValueError: If `stream_slice` is not provided (or is falsy).
        """
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to get request body json")
        cursor = self._get_or_create_cursor(self._to_partition_key(stream_slice.partition))
        return self._partition_router.get_request_body_json(  # type: ignore # this always returns a mapping
            stream_state=stream_state,
            stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}),
            next_page_token=next_page_token,
        ) | cursor.get_request_body_json(
            stream_state=stream_state,
            stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice),
            next_page_token=next_page_token,
        )

    def should_be_synced(self, record: Record) -> bool:
        """Delegate to the record's partition cursor, passing the record without its partition."""
        return self._get_cursor(record).should_be_synced(
            self._convert_record_to_cursor_record(record)
        )

    @staticmethod
    def _convert_record_to_cursor_record(record: Record) -> Record:
        """Return a copy of `record` whose associated slice keeps only the cursor part."""
        return Record(
            data=record.data,
            stream_name=record.stream_name,
            associated_slice=StreamSlice(
                partition={}, cursor_slice=record.associated_slice.cursor_slice
            )
            if record.associated_slice
            else None,
        )

    def _get_cursor(self, record: Record) -> DeclarativeCursor:
        """
        Return the cursor for the record's partition, creating it if missing.

        Raises:
            ValueError: If the record has no associated slice.
        """
        if not record.associated_slice:
            raise ValueError(
                "Invalid state as stream slices that are emitted should refer to an existing cursor"
            )
        return self._get_or_create_cursor(
            self._to_partition_key(record.associated_slice.partition)
        )

    def _create_cursor_for_partition(self, partition_key: str) -> None:
        """
        Dynamically creates and initializes a cursor for the specified partition.

        This method is required for `ConcurrentPerPartitionCursor`. For concurrent cursors,
        stream_slices is executed only for the concurrent cursor, so cursors per partition
        are not created for the declarative cursor. This method ensures that a cursor is available
        to create requests for the specified partition. The cursor is seeded with the state to
        migrate from when one was set by `set_initial_state`, and with an empty state otherwise.
        Unlike `generate_slices_from_partition`, it does not enforce the partition limit.

        Note:
            This is a temporary workaround and should be removed once the declarative cursor
            is decoupled from the concurrent cursor implementation.

        Args:
            partition_key (str): The unique identifier for the partition for which the cursor
                needs to be created.
        """
        cursor = self._create_cursor(self._default_partition_state())
        self._cursor_per_partition[partition_key] = cursor