diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index e212b0f2a..07d71e74f 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -361,7 +361,7 @@ def _group_streams( == DatetimeBasedCursorModel.__name__ and hasattr(declarative_stream.retriever, "stream_slicer") and isinstance( - declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor + declarative_stream.retriever.stream_slicer, ConcurrentPerPartitionCursor ) ): stream_state = self._connector_state_manager.get_stream_state( @@ -369,20 +369,7 @@ def _group_streams( ) stream_state = self._migrate_state(declarative_stream, stream_state) - partition_router = declarative_stream.retriever.stream_slicer._partition_router - - perpartition_cursor = ( - self._constructor.create_concurrent_cursor_from_perpartition_cursor( - state_manager=self._connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=incremental_sync_component_definition, - stream_name=declarative_stream.name, - stream_namespace=declarative_stream.namespace, - config=config or {}, - stream_state=stream_state, - partition_router=partition_router, - ) - ) + perpartition_cursor = declarative_stream.retriever.stream_slicer retriever = self._get_retriever(declarative_stream, stream_state) @@ -464,15 +451,7 @@ def _get_retriever( if retriever.cursor: retriever.cursor.set_initial_state(stream_state=stream_state) - # Similar to above, the ClientSideIncrementalRecordFilterDecorator cursor is a separate instance - # from the one initialized on the SimpleRetriever, so it also must also have state initialized - # for semi-incremental streams using is_client_side_incremental to filter properly - if isinstance(retriever.record_selector, RecordSelector) and isinstance( - retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator - ): - retriever.record_selector.record_filter._cursor.set_initial_state( - stream_state=stream_state - ) # type: ignore # After non-concurrent cursors are deprecated we can remove these cursor workarounds + # FIXME comment: Removing this as the concurrent state should already have the information # We zero it out here, but since this is a cursor reference, the state is still properly # instantiated for the other components that reference it diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py index a0c541dc4..4bafdb3cc 100644 --- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py @@ -11,11 +11,18 @@ from datetime import timedelta from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional +from airbyte_cdk.models import ( + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + StreamDescriptor, +) from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import ( Timer, iterate_with_last_flag_and_state, -) +) # FIXME since it relies on the declarative package, this can generate circular imports errors from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import ( @@ -150,6 +157,7 @@ def close_partition(self, partition: Partition) -> None: raise ValueError("stream_slice cannot be None") partition_key = self._to_partition_key(stream_slice.partition) + logger.warning(f"close_partition... Semaphore for {partition_key}") with self._lock: self._semaphore_per_partition[partition_key].acquire() if not self._use_global_cursor: @@ -204,6 +212,7 @@ def _check_and_update_parent_state(self) -> None: for p_key in list(self._semaphore_per_partition.keys()): sem = self._semaphore_per_partition[p_key] if p_key in self._finished_partitions and sem._value == 0: + logger.warning(f"_check_and_update_parent_state delete semaphore for {p_key}") del self._semaphore_per_partition[p_key] logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") if p_key == earliest_key: @@ -261,6 +270,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: slices, self._partition_router.get_stream_state ): yield from self._generate_slices_from_partition(partition, parent_state) + self._parent_state = self._partition_router.get_stream_state() def _generate_slices_from_partition( self, partition: StreamSlice, parent_state: Mapping[str, Any] @@ -289,6 +299,7 @@ def _generate_slices_from_partition( ] != parent_state ): + print(f"GODO:\n\t{parent_state}") # FIXME parent state needs to be tracked in substream partition router self._partition_parent_state_map[partition_key] = deepcopy(parent_state) for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state( @@ -296,6 +307,7 @@ def _generate_slices_from_partition( lambda: None, ): self._semaphore_per_partition[partition_key].release() + logger.warning(f"Generating... Semaphore for {partition_key} is {self._semaphore_per_partition[partition_key]._value}") if is_last_slice: self._finished_partitions.add(partition_key) yield StreamSlice( @@ -418,7 +430,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None: self._parent_state = stream_state["parent_state"] # Set parent state for partition routers based on parent streams - self._partition_router.set_initial_state(stream_state) + self._partition_router.set_initial_state(stream_state) # FIXME can we remove this thing? this would probably be a breaking change though... def _set_global_state(self, stream_state: Mapping[str, Any]) -> None: """ @@ -489,10 +501,31 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor: partition_key = self._to_partition_key(record.associated_slice.partition) if partition_key not in self._cursor_per_partition: raise ValueError( - "Invalid state as stream slices that are emitted should refer to an existing cursor" + f"Invalid state as stream slices that are emitted should refer to an existing cursor but {partition_key} is unknown" ) cursor = self._cursor_per_partition[partition_key] return cursor def limit_reached(self) -> bool: return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT + + @staticmethod + def get_parent_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]: + return AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(parent_stream_name, None), + stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name]) + ) + ) if stream_state and "parent_state" in stream_state else None + + @staticmethod + def get_global_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]: + return AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(parent_stream_name, None), + stream_state=AirbyteStateBlob(stream_state["state"]) + ) + ) if stream_state and "state" in stream_state else None + diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py index 610a15bbd..32cd6e28d 100644 --- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py @@ -1,7 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # - +import logging import threading import time from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union @@ -12,7 +12,7 @@ from airbyte_cdk.sources.types import Record, StreamSlice, StreamState T = TypeVar("T") - +logger = logging.getLogger(__name__) def iterate_with_last_flag_and_state( generator: Iterable[T], get_stream_state_func: Callable[[], Optional[Mapping[str, StreamState]]] @@ -40,6 +40,7 @@ def iterate_with_last_flag_and_state( return # Return an empty iterator for next_item in iterator: + logger.info(f"slice: {current}, state: {state}") yield current, False, state current = next_item state = get_stream_state_func() diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 25840f06f..43ca70537 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -7,6 +7,7 @@ import datetime import importlib import inspect +import logging import re from functools import partial from typing import ( @@ -27,7 +28,15 @@ from isodate import parse_duration from pydantic.v1 import BaseModel -from airbyte_cdk.models import FailureType, Level +from airbyte_cdk.models import ( + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + FailureType, + Level, + StreamDescriptor, +) from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker @@ -94,9 +103,7 @@ CursorFactory, DatetimeBasedCursor, DeclarativeCursor, - GlobalSubstreamCursor, PerPartitionCursor, - PerPartitionWithGlobalCursor, ResumableFullRefreshCursor, ) from airbyte_cdk.sources.declarative.interpolation import InterpolatedString @@ -106,7 +113,6 @@ ) from airbyte_cdk.sources.declarative.models import ( CustomStateMigration, - GzipDecoder, ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( AddedFieldDefinition as AddedFieldDefinitionModel, @@ -490,6 +496,10 @@ ) from airbyte_cdk.sources.declarative.spec import Spec from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicer +from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( + DeclarativePartitionFactory, + StreamSlicerPartitionGenerator, +) from airbyte_cdk.sources.declarative.transformations import ( AddFields, RecordTransformation, @@ -526,6 +536,9 @@ Rate, UnlimitedCallRatePolicy, ) +from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( + AlwaysAvailableAvailabilityStrategy, +) from airbyte_cdk.sources.streams.concurrent.clamping import ( ClampingEndProvider, ClampingStrategy, @@ -535,7 +548,14 @@ WeekClampingStrategy, Weekday, ) -from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField +from airbyte_cdk.sources.streams.concurrent.cursor import ( + ConcurrentCursor, + Cursor, + CursorField, + FinalStateCursor, +) +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream +from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, DateTimeStreamStateConverter, @@ -576,7 +596,7 @@ def __init__( self._emit_connector_builder_messages = emit_connector_builder_messages self._disable_retries = disable_retries self._disable_cache = disable_cache - self._disable_resumable_full_refresh = disable_resumable_full_refresh + self._disable_resumable_full_refresh = disable_resumable_full_refresh # FIXME can we remove this? self._message_repository = message_repository or InMemoryMessageRepository( self._evaluate_log_level(emit_connector_builder_messages) ) @@ -1381,6 +1401,7 @@ def create_concurrent_cursor_from_perpartition_cursor( f"Expected manifest component of type {model_type.__name__}, but received {component_type} instead" ) + component_definition["$parameters"] = component_definition.get("parameters", {}) datetime_based_cursor_model = model_type.parse_obj(component_definition) if not isinstance(datetime_based_cursor_model, DatetimeBasedCursorModel): @@ -1390,7 +1411,7 @@ def create_concurrent_cursor_from_perpartition_cursor( interpolated_cursor_field = InterpolatedString.create( datetime_based_cursor_model.cursor_field, - parameters=datetime_based_cursor_model.parameters or {}, + parameters=component_definition.get("parameters", component_definition.get("$parameters", {})), # FIXME being called from `_group_stream`, parameters is not propagated ) cursor_field = CursorField(interpolated_cursor_field.eval(config=config)) @@ -1728,7 +1749,10 @@ def create_datetime_based_cursor( def create_declarative_stream( self, model: DeclarativeStreamModel, config: Config, **kwargs: Any - ) -> DeclarativeStream: + ) -> Union[DeclarativeStream, DefaultStream]: + """ + As a transition period, the stream returned can either be a DeclarativeStream or a DefaultStream depending on kwargs["concurrent"] which is a bool + """ # When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field # components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the # Retriever. This is done in the declarative stream not the retriever to support custom retrievers. The custom create methods in @@ -1747,24 +1771,11 @@ def create_declarative_stream( and hasattr(model.incremental_sync, "is_client_side_incremental") and model.incremental_sync.is_client_side_incremental ): - supported_slicers = ( - DatetimeBasedCursor, - GlobalSubstreamCursor, - PerPartitionWithGlobalCursor, - ) - if combined_slicers and not isinstance(combined_slicers, supported_slicers): + if combined_slicers and not isinstance(combined_slicers, Cursor): raise ValueError( - "Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead" + f"Unsupported slicer `{type(combined_slicers)}` is used with is_client_side_incremental" ) - cursor = ( - combined_slicers - if isinstance( - combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor) - ) - else self._create_component_from_model(model=model.incremental_sync, config=config) - ) - - client_side_incremental_sync = {"cursor": cursor} + client_side_incremental_sync = {"cursor": combined_slicers} if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel): cursor_model = model.incremental_sync @@ -1859,6 +1870,30 @@ def create_declarative_stream( options["name"] = model.name schema_loader = DefaultSchemaLoader(config=config, parameters=options) + if "concurrent" in kwargs and kwargs["concurrent"]: + stream_name = model.name or "" + cursor = combined_slicers if combined_slicers and isinstance(combined_slicers, Cursor) else FinalStateCursor(stream_name, None, + self._message_repository) + partition_generator = StreamSlicerPartitionGenerator( + DeclarativePartitionFactory( + stream_name, + schema_loader.get_json_schema(), + retriever, + self._message_repository, + ), + cursor, + ) + + return DefaultStream( + partition_generator=partition_generator, + name=stream_name, + json_schema=schema_loader.get_json_schema(), + availability_strategy=AlwaysAvailableAvailabilityStrategy(), # FIXME it seems this is what we do in the ConcurrentDeclarativeSource but it feels wrong + primary_key=get_primary_key_from_stream(primary_key), + cursor_field=cursor.cursor_field.cursor_field_key if hasattr(cursor, "cursor_field") else "", # FIXME we should have the cursor field has part of the interface of cursor + logger=logging.getLogger(f"airbyte.{stream_name}"), # FIXME this is a breaking change compared to the old implementation, + cursor=cursor, + ) return DeclarativeStream( name=model.name or "", primary_key=primary_key, @@ -1908,52 +1943,29 @@ def _build_incremental_cursor( stream_slicer: Optional[PartitionRouter], config: Config, ) -> Optional[StreamSlicer]: + stream_state = self._connector_state_manager.get_stream_state( + stream_name=model.name, namespace=None + ) # FIXME should this be in create_concurrent_cursor_from_perpartition_cursor if model.incremental_sync and stream_slicer: - if model.retriever.type == "AsyncRetriever": - return self.create_concurrent_cursor_from_perpartition_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing - state_manager=self._connector_state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=model.incremental_sync.__dict__, - stream_name=model.name or "", - stream_namespace=None, - config=config or {}, - stream_state={}, - partition_router=stream_slicer, - ) - - incremental_sync_model = model.incremental_sync - cursor_component = self._create_component_from_model( - model=incremental_sync_model, config=config - ) - is_global_cursor = ( - hasattr(incremental_sync_model, "global_substream_cursor") - and incremental_sync_model.global_substream_cursor - ) - - if is_global_cursor: - return GlobalSubstreamCursor( - stream_cursor=cursor_component, partition_router=stream_slicer - ) - return PerPartitionWithGlobalCursor( - cursor_factory=CursorFactory( - lambda: self._create_component_from_model( - model=incremental_sync_model, config=config - ), - ), + return self.create_concurrent_cursor_from_perpartition_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + state_manager=self._connector_state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + stream_state=stream_state, partition_router=stream_slicer, - stream_cursor=cursor_component, ) elif model.incremental_sync: - if model.retriever.type == "AsyncRetriever": - return self.create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing - model_type=DatetimeBasedCursorModel, - component_definition=model.incremental_sync.__dict__, - stream_name=model.name or "", - stream_namespace=None, - config=config or {}, - stream_state_migrations=model.state_migrations, - ) - return self._create_component_from_model(model=model.incremental_sync, config=config) # type: ignore[no-any-return] # Will be created Cursor as stream_slicer_model is model.incremental_sync + return self.create_concurrent_cursor_from_datetime_based_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing + model_type=DatetimeBasedCursorModel, + component_definition=model.incremental_sync.__dict__, + stream_name=model.name or "", + stream_namespace=None, + config=config or {}, + stream_state_migrations=model.state_migrations, + ) return None def _build_resumable_cursor( @@ -1981,44 +1993,14 @@ def _build_resumable_cursor( def _merge_stream_slicers( self, model: DeclarativeStreamModel, config: Config ) -> Optional[StreamSlicer]: - retriever_model = model.retriever - stream_slicer = self._build_stream_slicer_from_partition_router( - retriever_model, config, stream_name=model.name + model.retriever, config, stream_name=model.name ) - if retriever_model.type == "AsyncRetriever": - is_not_datetime_cursor = ( - model.incremental_sync.type != "DatetimeBasedCursor" - if model.incremental_sync - else None - ) - is_partition_router = ( - bool(retriever_model.partition_router) if model.incremental_sync else None - ) - - if is_not_datetime_cursor: - # We are currently in a transition to the Concurrent CDK and AsyncRetriever can only work with the - # support or unordered slices (for example, when we trigger reports for January and February, the report - # in February can be completed first). Once we have support for custom concurrent cursor or have a new - # implementation available in the CDK, we can enable more cursors here. - raise ValueError( - "AsyncRetriever with cursor other than DatetimeBasedCursor is not supported yet." - ) - - if is_partition_router and not stream_slicer: - # Note that this development is also done in parallel to the per partition development which once merged - # we could support here by calling create_concurrent_cursor_from_perpartition_cursor - raise ValueError("Per partition state is not supported yet for AsyncRetriever.") - if model.incremental_sync: return self._build_incremental_cursor(model, stream_slicer, config) - return ( - stream_slicer - if self._disable_resumable_full_refresh - else self._build_resumable_cursor(retriever_model, stream_slicer) - ) + return stream_slicer def create_default_error_handler( self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any @@ -2611,7 +2593,7 @@ def create_parent_stream_config( self, model: ParentStreamConfigModel, config: Config, **kwargs: Any ) -> ParentStreamConfig: declarative_stream = self._create_component_from_model( - model.stream, config=config, **kwargs + model.stream, config=config, concurrent=True, **kwargs ) request_option = ( self._create_component_from_model(model.request_option, config=config) @@ -2935,9 +2917,9 @@ def create_simple_retriever( cursor = stream_slicer if isinstance(stream_slicer, DeclarativeCursor) else None if ( - not isinstance(stream_slicer, DatetimeBasedCursor) - or type(stream_slicer) is not DatetimeBasedCursor - ): + not isinstance(stream_slicer, (ConcurrentCursor, ConcurrentPerPartitionCursor)) + or type(stream_slicer) not in {ConcurrentCursor, ConcurrentPerPartitionCursor} + ): # FIXME this condition is probably wrong given the existance of IncrementingCountCursor # Many of the custom component implementations of DatetimeBasedCursor override get_request_params() (or other methods). # Because we're decoupling RequestOptionsProvider from the Cursor, custom components will eventually need to reimplement # their own RequestOptionsProvider. However, right now the existing StreamSlicer/Cursor still can act as the SimpleRetriever's @@ -3340,12 +3322,41 @@ def create_substream_partition_router( def _create_message_repository_substream_wrapper( self, model: ParentStreamConfigModel, config: Config, **kwargs: Any ) -> Any: + # getting the parent state + child_state = self._connector_state_manager.get_stream_state(kwargs["stream_name"], None) + if model.incremental_dependency and child_state: + parent_stream_name = model.stream.name or "" + parent_state = ConcurrentPerPartitionCursor.get_parent_state(child_state, parent_stream_name) + + if model.incremental_dependency and not parent_state: + # there are two migration cases: state value from child stream or from global state + parent_state = ConcurrentPerPartitionCursor.get_global_state(child_state, parent_stream_name) + + if not parent_state and not isinstance(parent_state, dict): + cursor_field = InterpolatedString.create( + model.stream.incremental_sync.cursor_field, + parameters=model.stream.incremental_sync.parameters or {}, + ).eval(config) + cursor_values = child_state.values() + if cursor_values: + parent_state = AirbyteStateMessage( + type=AirbyteStateType.STREAM, + stream=AirbyteStreamState( + stream_descriptor=StreamDescriptor(name=parent_stream_name, namespace=None), + stream_state=AirbyteStateBlob({cursor_field: list(cursor_values)[0]}), + ), + ) + connector_state_manager = ConnectorStateManager([parent_state] if parent_state else []) + else: + connector_state_manager = ConnectorStateManager([]) + substream_factory = ModelToComponentFactory( limit_pages_fetched_per_slice=self._limit_pages_fetched_per_slice, limit_slices_fetched=self._limit_slices_fetched, emit_connector_builder_messages=self._emit_connector_builder_messages, disable_retries=self._disable_retries, disable_cache=self._disable_cache, + connector_state_manager=connector_state_manager, message_repository=LogAppenderMessageRepositoryDecorator( {"airbyte_cdk": {"stream": {"is_substream": True}}, "http": {"is_auxiliary": True}}, self._message_repository, diff --git a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 000beeff9..9db5044f9 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ b/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -7,7 +7,17 @@ import json import logging from dataclasses import InitVar, dataclass -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + TypeVar, + Union, +) import dpath import requests @@ -20,11 +30,52 @@ RequestOption, RequestOptionType, ) +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream +from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState from airbyte_cdk.utils import AirbyteTracedException -if TYPE_CHECKING: - from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream +T = TypeVar("T") + + +def iterate_with_last_flag(generator: Iterable[T]) -> Iterable[tuple[T, bool]]: + + iterator = iter(generator) + + try: + current = next(iterator) + except StopIteration: + return # Return an empty iterator + + for next_item in iterator: + yield current, False + current = next_item + + yield current, True + + +class InMemoryPartition(Partition): + + def __init__(self, stream_name, _slice): + self._stream_name = stream_name + self._slice = _slice + + def stream_name(self) -> str: + return self._stream_name + + def read(self) -> Iterable[Record]: + yield from [] + + def to_slice(self) -> Optional[Mapping[str, Any]]: + return self._slice + + def __hash__(self) -> int: + if self._slice: + # Convert the slice to a string so that it can be hashed + s = json.dumps(self._slice, sort_keys=True) + return hash((self._name, s)) + else: + return hash(self._name) @dataclass @@ -40,7 +91,7 @@ class ParentStreamConfig: incremental_dependency (bool): Indicates if the parent stream should be read incrementally. """ - stream: "DeclarativeStream" # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream + stream: "DefaultStream" parent_key: Union[InterpolatedString, str] partition_field: Union[InterpolatedString, str] config: Config @@ -176,59 +227,55 @@ def stream_slices(self) -> Iterable[StreamSlice]: for field_path in parent_stream_config.extra_fields ] - # read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does - # not support either substreams or RFR, but something that needs to be considered once we do - for parent_record in parent_stream.read_only_records(): - parent_partition = None - # Skip non-records (eg AirbyteLogMessage) - if isinstance(parent_record, AirbyteMessage): - self.logger.warning( - f"Parent stream {parent_stream.name} returns records of type AirbyteMessage. This SubstreamPartitionRouter is not able to checkpoint incremental parent state." - ) - if parent_record.type == MessageType.RECORD: - parent_record = parent_record.record.data # type: ignore[union-attr, assignment] # record is always a Record - else: - continue - elif isinstance(parent_record, Record): + for partition, is_last_slice in iterate_with_last_flag(parent_stream.generate_partitions()): + for parent_record, is_last_record_in_slice in iterate_with_last_flag(partition.read()): + self.logger.warning(f"Parent record is {parent_record}") + + parent_stream.cursor.observe(parent_record) + + # Skip non-records (eg AirbyteLogMessage) parent_partition = ( parent_record.associated_slice.partition if parent_record.associated_slice else {} ) - parent_record = parent_record.data - elif not isinstance(parent_record, Mapping): - # The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid - raise AirbyteTracedException( - message=f"Parent stream returned records as invalid type {type(parent_record)}" - ) - try: - partition_value = dpath.get( - parent_record, # type: ignore [arg-type] - parent_field, + record_data = parent_record.data + + try: + partition_value = dpath.get( + record_data, # type: ignore [arg-type] + parent_field, + ) + except KeyError: + # FIXME a log here would go a long way for debugging + continue + + # Add extra fields + extracted_extra_fields = self._extract_extra_fields(record_data, extra_fields) + + if parent_stream_config.lazy_read_pointer: + extracted_extra_fields = { + "child_response": self._extract_child_response( + record_data, + parent_stream_config.lazy_read_pointer, # type: ignore[arg-type] # lazy_read_pointer type handeled in __post_init__ of parent_stream_config + ), + **extracted_extra_fields, + } + + if is_last_record_in_slice: + parent_stream.cursor.close_partition(partition) + #if is_last_slice: + # parent_stream.cursor.ensure_at_least_one_state_emitted() + yield StreamSlice( + partition={ + partition_field: partition_value, + "parent_slice": parent_partition or {}, + }, + cursor_slice={}, + extra_fields=extracted_extra_fields, ) - except KeyError: - continue - - # Add extra fields - extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields) - - if parent_stream_config.lazy_read_pointer: - extracted_extra_fields = { - "child_response": self._extract_child_response( - parent_record, - parent_stream_config.lazy_read_pointer, # type: ignore[arg-type] # lazy_read_pointer type handeled in __post_init__ of parent_stream_config - ), - **extracted_extra_fields, - } - - yield StreamSlice( - partition={ - partition_field: partition_value, - "parent_slice": parent_partition or {}, - }, - cursor_slice={}, - extra_fields=extracted_extra_fields, - ) + parent_stream.cursor.ensure_at_least_one_state_emitted() + def _extract_child_response( self, parent_record: Mapping[str, Any] | AirbyteMessage, pointer: List[InterpolatedString] @@ -414,7 +461,7 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: parent_state = {} for parent_config in self.parent_stream_configs: if parent_config.incremental_dependency: - parent_state[parent_config.stream.name] = copy.deepcopy(parent_config.stream.state) + parent_state[parent_config.stream.name] = copy.deepcopy(parent_config.stream.cursor.state) return parent_state @property diff --git a/airbyte_cdk/sources/declarative/retrievers/retriever.py b/airbyte_cdk/sources/declarative/retrievers/retriever.py index 155de5782..3221f81e0 100644 --- a/airbyte_cdk/sources/declarative/retrievers/retriever.py +++ b/airbyte_cdk/sources/declarative/retrievers/retriever.py @@ -5,9 +5,8 @@ from abc import abstractmethod from typing import Any, Iterable, Mapping, Optional -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice from airbyte_cdk.sources.streams.core import StreamData -from airbyte_cdk.sources.types import StreamState +from airbyte_cdk.sources.types import StreamSlice, StreamState class Retriever: diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py index 2fc156e8c..6ebb3e5be 100644 --- a/airbyte_cdk/sources/message/repository.py +++ b/airbyte_cdk/sources/message/repository.py @@ -92,7 +92,8 @@ def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) def consume_queue(self) -> Iterable[AirbyteMessage]: while self._message_queue: - yield self._message_queue.popleft() + x = self._message_queue.popleft() + yield x class LogAppenderMessageRepositoryDecorator(MessageRepository): @@ -107,6 +108,8 @@ def __init__( self._log_level = log_level def emit_message(self, message: AirbyteMessage) -> None: + if message.type == Type.STATE: + return # FIXME this is horribly dumb but allows me to test not emitting state messages. We can probably create another decorator that filters and set this only for substream partition router self._decorated.emit_message(message) def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 88d15bc8a..17aa8b913 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -80,7 +80,7 @@ def stream_slices(self) -> Iterable[StreamSlice]: Subclasses can override this method to provide actual behavior. """ yield StreamSlice(partition={}, cursor_slice={}) - + class FinalStateCursor(Cursor): """Cursor that is used to guarantee at least one state message is emitted for a concurrent stream.""" diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index 7679a1eb6..509eeeeaf 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -25,7 +25,7 @@ def __init__( json_schema: Mapping[str, Any], availability_strategy: AbstractAvailabilityStrategy, primary_key: List[str], - cursor_field: Optional[str], + cursor_field: Optional[str], # FIXME can't we deduce this from self._cursor? logger: Logger, cursor: Cursor, namespace: Optional[str] = None, diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py index 042a430aa..88c0d7923 100644 --- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py +++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py @@ -295,8 +295,9 @@ } STREAM_NAME = "post_comment_votes" +START_DATE = "2024-01-01T00:00:01Z" CONFIG = { - "start_date": "2024-01-01T00:00:01Z", + "start_date": START_DATE, "credentials": {"email": "email", "api_token": "api_token"}, } @@ -388,7 +389,6 @@ def _run_read( # Existing Constants for Dates -START_DATE = "2024-01-01T00:00:01Z" # Start of the sync POST_1_UPDATED_AT = "2024-01-30T00:00:00Z" # Latest update date for post 1 POST_2_UPDATED_AT = "2024-01-29T00:00:00Z" # Latest update date for post 2 POST_3_UPDATED_AT = "2024-01-28T00:00:00Z" # Latest update date for post 3 @@ -675,7 +675,7 @@ def _run_read( "id": 10, "parent_slice": {"id": 1, "parent_slice": {}}, }, - "cursor": {"created_at": INITIAL_STATE_PARTITION_10_CURSOR_TIMESTAMP}, + "cursor": {"created_at": INITIAL_STATE_PARTITION_10_CURSOR_TIMESTAMP}, # that's a very fucked up state as it has different format for different partition }, { "partition": { @@ -742,7 +742,7 @@ def _run_read( ), ], ) -def test_incremental_parent_state_no_incremental_dependency( +def test_incremental_parent_state_no_incremental_dependency( #posts/1/comments/10/votes test_name, manifest, mock_requests, expected_records, initial_state, expected_state ): """ @@ -840,6 +840,9 @@ def run_incremental_parent_state_test( # For each intermediate state, perform another read starting from that state for state, records_before_state in intermediate_states[:-1]: + import logging + logger = logging.getLogger() + logger.warning(f"Running tmp state: {state.stream.stream_state.__dict__}") output_intermediate = _run_read(manifest, CONFIG, STREAM_NAME, [state]) records_from_state = [r.record.data for r in output_intermediate.records] @@ -869,9 +872,14 @@ def run_incremental_parent_state_test( # Assert that the final state matches the expected state for all runs for i, final_state in enumerate(final_states): - assert ( - final_state in expected_states - ), f"Final state mismatch at run {i + 1}. Expected {expected_states}, got {final_state}" + if len(expected_states) == 1: + assert ( + final_state == expected_states[0] + ), f"Final state mismatch at run {i + 1}. Expected {expected_states}, got {final_state}" + else: + assert ( + final_state in expected_states + ), f"Final state mismatch at run {i + 1}. Expected {expected_states}, got {final_state}" @pytest.mark.parametrize( @@ -900,6 +908,11 @@ def run_incremental_parent_state_test( f"https://api.example.com/community/posts?per_page=100&start_time={PARENT_POSTS_CURSOR}&page=2", {"posts": [{"id": 3, "updated_at": POST_3_UPDATED_AT}]}, ), + # Once state is updated, we might fetch the most recent record + ( + f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}", + {"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]}, + ), # Fetch the first page of comments for post 1 ( "https://api.example.com/community/posts/1/comments?per_page=100", @@ -1134,6 +1147,9 @@ def run_incremental_parent_state_test( "posts": {"updated_at": POST_1_UPDATED_AT} }, # post 1 is the latest "lookback_window": 1, + "state": { + "updated_at": "2024-01-25T00:00:00Z" + }, "states": [ { "partition": {"id": 1, "parent_slice": {}}, @@ -1561,12 +1577,12 @@ def test_incremental_parent_state_migration( "states": [ { "partition": {"id": 1, "parent_slice": {}}, - "cursor": {"updated_at": PARENT_COMMENT_CURSOR_PARTITION_1}, + "cursor": {"updated_at": START_DATE}, # interesting thing is that the concurrent cursor takes the max between the state and the start } ], - "state": {}, "use_global_cursor": False, "parent_state": {"posts": {"updated_at": PARENT_POSTS_CURSOR}}, + "lookback_window": 1, } }, "states": [ diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index f628eeb3b..dc5379be4 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -5,13 +5,14 @@ # mypy: ignore-errors from datetime import datetime, timedelta, timezone from typing import Any, Iterable, Mapping +from unittest.mock import Mock import freezegun import pytest import requests from pydantic.v1 import ValidationError -from airbyte_cdk import AirbyteTracedException +from airbyte_cdk import AirbyteTracedException, InMemoryMessageRepository from airbyte_cdk.models import ( AirbyteStateBlob, AirbyteStateMessage, @@ -42,6 +43,7 @@ ClientSideIncrementalRecordFilterDecorator, ) from airbyte_cdk.sources.declarative.incremental import ( + ConcurrentPerPartitionCursor, CursorFactory, DatetimeBasedCursor, PerPartitionCursor, @@ -164,7 +166,7 @@ MonthClampingStrategy, WeekClampingStrategy, ) -from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor +from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, ) @@ -185,7 +187,7 @@ transformer = ManifestComponentTransformer() -input_config = {"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]} +input_config = {"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"], "start_time": "2025-01-01T00:00:00.000+00:00"} def test_create_check_stream(): @@ -904,18 +906,12 @@ def test_stream_with_incremental_and_retriever_with_partition_router(): assert isinstance(stream, DeclarativeStream) assert isinstance(stream.retriever, SimpleRetriever) - assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) + assert isinstance(stream.retriever.stream_slicer, ConcurrentPerPartitionCursor) datetime_stream_slicer = ( - stream.retriever.stream_slicer._per_partition_cursor._cursor_factory.create() + stream.retriever.stream_slicer._cursor_factory.create({}, None) ) - assert isinstance(datetime_stream_slicer, DatetimeBasedCursor) - assert isinstance(datetime_stream_slicer._start_datetime, MinMaxDatetime) - assert datetime_stream_slicer._start_datetime.datetime.string == "{{ config['start_time'] }}" - assert isinstance(datetime_stream_slicer._end_datetime, MinMaxDatetime) - assert datetime_stream_slicer._end_datetime.datetime.string == "{{ config['end_time'] }}" - assert datetime_stream_slicer.step == "P10D" - assert datetime_stream_slicer.cursor_field.string == "created" + assert isinstance(datetime_stream_slicer, ConcurrentCursor) list_stream_slicer = stream.retriever.stream_slicer._partition_router assert isinstance(list_stream_slicer, ListPartitionRouter) @@ -1032,8 +1028,7 @@ def test_resumable_full_refresh_stream(): assert isinstance(stream.retriever.record_selector, RecordSelector) - assert isinstance(stream.retriever.stream_slicer, ResumableFullRefreshCursor) - assert isinstance(stream.retriever.cursor, ResumableFullRefreshCursor) + assert isinstance(stream.retriever.stream_slicer, SinglePartitionRouter) assert isinstance(stream.retriever.paginator, DefaultPaginator) assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) @@ -1292,7 +1287,7 @@ def test_client_side_incremental_with_partition_router(): assert stream.retriever.record_selector.transform_before_filtering == True assert isinstance( stream.retriever.record_selector.record_filter._cursor, - PerPartitionWithGlobalCursor, + ConcurrentPerPartitionCursor, ) @@ -1864,7 +1859,7 @@ def test_create_default_paginator(): "subcomponent_field_with_hint", DpathExtractor( field_path=[], - config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"], "start_time": "2025-01-01T00:00:00.000+00:00"}, decoder=JsonDecoder(parameters={}), parameters={}, ), @@ -1880,7 +1875,7 @@ def test_create_default_paginator(): "subcomponent_field_with_hint", DpathExtractor( field_path=[], - config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"], "start_time": "2025-01-01T00:00:00.000+00:00"}, parameters={}, ), None, @@ -1968,11 +1963,11 @@ def test_create_default_paginator(): DefaultPaginator( pagination_strategy=OffsetIncrement( page_size=10, - config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"], "start_time": "2025-01-01T00:00:00.000+00:00"}, parameters={}, ), url_base="https://physical_100.com", - config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}, + config={"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"], "start_time": "2025-01-01T00:00:00.000+00:00"}, parameters={"decoder": {"type": "JsonDecoder"}}, ), None, @@ -2448,7 +2443,7 @@ def test_default_schema_loader(self): "cursor_granularity": "PT0.000001S", }, None, - DatetimeBasedCursor, + ConcurrentCursor, id="test_create_simple_retriever_with_incremental", ), pytest.param( @@ -2458,7 +2453,7 @@ def test_default_schema_loader(self): "values": "{{config['repos']}}", "cursor_field": "a_key", }, - PerPartitionCursor, + ListPartitionRouter, id="test_create_simple_retriever_with_partition_router", ), pytest.param( @@ -2476,7 +2471,7 @@ def test_default_schema_loader(self): "values": "{{config['repos']}}", "cursor_field": "a_key", }, - PerPartitionWithGlobalCursor, + ConcurrentPerPartitionCursor, id="test_create_simple_retriever_with_incremental_and_partition_router", ), pytest.param( @@ -2501,7 +2496,7 @@ def test_default_schema_loader(self): "cursor_field": "b_key", }, ], - PerPartitionWithGlobalCursor, + ConcurrentPerPartitionCursor, id="test_create_simple_retriever_with_partition_routers_multiple_components", ), pytest.param( @@ -2548,7 +2543,7 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e assert isinstance(stream.retriever.stream_slicer, expected_type) if incremental and partition_router: - assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) + assert isinstance(stream.retriever.stream_slicer, ConcurrentPerPartitionCursor) if isinstance(partition_router, list) and len(partition_router) > 1: assert isinstance( stream.retriever.stream_slicer._partition_router, CartesianProductStreamSlicer @@ -2556,9 +2551,6 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e assert len(stream.retriever.stream_slicer._partition_router.stream_slicers) == len( partition_router ) - elif partition_router and isinstance(partition_router, list) and len(partition_router) > 1: - assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) - assert len(stream.retriever.stream_slicer.stream_slicerS) == len(partition_router) def test_simple_retriever_emit_log_messages(): @@ -2885,15 +2877,17 @@ def test_use_request_options_provider_for_datetime_based_cursor(): }, } - datetime_based_cursor = DatetimeBasedCursor( - start_datetime=MinMaxDatetime(datetime="{{ config.start_time }}", parameters={}), - step="P5D", - cursor_field="updated_at", - datetime_format="%Y-%m-%dT%H:%M:%S.%f%z", - cursor_granularity="PT1S", - is_compare_strictly=True, - config=config, - parameters={}, + datetime_based_cursor = ConcurrentCursor( + stream_name="a_stream_name", + stream_namespace="a_stream_namespace", + stream_state=None, + message_repository=InMemoryMessageRepository(), + connector_state_manager=Mock(), + connector_state_converter=Mock(), + cursor_field=CursorField("updated_at"), + slice_boundary_fields=("start","end"), + start=datetime.now(), + end_provider=datetime.now(), ) datetime_based_request_options_provider = DatetimeBasedRequestOptionsProvider( @@ -2927,8 +2921,7 @@ def test_use_request_options_provider_for_datetime_based_cursor(): assert retriever.primary_key == "id" assert retriever.name == "Test" - assert isinstance(retriever.cursor, DatetimeBasedCursor) - assert isinstance(retriever.stream_slicer, DatetimeBasedCursor) + assert isinstance(retriever.stream_slicer, ConcurrentCursor) assert isinstance(retriever.request_option_provider, DatetimeBasedRequestOptionsProvider) assert ( diff --git a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py index b0dcd272c..793abb789 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py @@ -145,12 +145,6 @@ def __hash__(self) -> int: else: return hash(self._name) - def close(self) -> None: - self._is_closed = True - - def is_closed(self) -> bool: - return self._is_closed - class ConcurrentSourceBuilder(SourceBuilder[ConcurrentCdkSource]): def __init__(self):