diff --git a/airbyte_cdk/connector_builder/test_reader/reader.py b/airbyte_cdk/connector_builder/test_reader/reader.py index 5c16798a2..e7399f3f6 100644 --- a/airbyte_cdk/connector_builder/test_reader/reader.py +++ b/airbyte_cdk/connector_builder/test_reader/reader.py @@ -120,7 +120,11 @@ def run_test_read( deprecation_warnings: List[LogMessage] = source.deprecation_warnings() schema_inferrer = SchemaInferrer( - self._pk_to_nested_and_composite_field(stream.primary_key) if stream else None, + self._pk_to_nested_and_composite_field( + stream.primary_key if hasattr(stream, "primary_key") else stream._primary_key # type: ignore # We are accessing the private property here as the primary key is not exposed. We should either expose it or use `as_airbyte_stream` to retrieve it as this is the "official" way where it is exposed in the Airbyte protocol + ) + if stream + else None, self._cursor_field_to_nested_and_composite_field(stream.cursor_field) if stream else None, diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 2bcc4b8c9..720934a11 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,7 +3,7 @@ # import logging -from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple +from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union from airbyte_cdk.models import ( AirbyteCatalog, @@ -15,10 +15,6 @@ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.concurrency_level import ConcurrencyLevel from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream -from airbyte_cdk.sources.declarative.extractors import RecordSelector -from airbyte_cdk.sources.declarative.extractors.record_filter import ( - ClientSideIncrementalRecordFilterDecorator, -) from airbyte_cdk.sources.declarative.incremental import ( ConcurrentPerPartitionCursor, GlobalSubstreamCursor, @@ -28,7 +24,6 @@ PerPartitionWithGlobalCursor, ) from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource -from airbyte_cdk.sources.declarative.models import FileUploader from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( ConcurrencyLevel as ConcurrencyLevelModel, ) @@ -84,7 +79,6 @@ def __init__( # incremental streams running in full refresh. component_factory = component_factory or ModelToComponentFactory( emit_connector_builder_messages=emit_connector_builder_messages, - disable_resumable_full_refresh=True, connector_state_manager=self._connector_state_manager, max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"), ) @@ -180,7 +174,7 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> Airbyte ] ) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> List[Union[Stream, AbstractStream]]: # type: ignore # we are migrating away from the AbstractSource and are expecting that this will only be called by ConcurrentDeclarativeSource or the Connector Builder """ The `streams` method is used as part of the AbstractSource in the following cases: * ConcurrentDeclarativeSource.check -> ManifestDeclarativeSource.check -> AbstractSource.check -> DeclarativeSource.check_connection -> CheckStream.check_connection -> streams @@ -210,6 +204,10 @@ def _group_streams( # these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible, # so we need to treat them as synchronous + if isinstance(declarative_stream, AbstractStream): + concurrent_streams.append(declarative_stream) + continue + supports_file_transfer = ( isinstance(declarative_stream, DeclarativeStream) and "file_uploader" in name_to_stream_mapping[declarative_stream.name] @@ -278,10 +276,10 @@ def _group_streams( partition_generator = StreamSlicerPartitionGenerator( partition_factory=DeclarativePartitionFactory( - declarative_stream.name, - declarative_stream.get_json_schema(), - retriever, - self.message_repository, + stream_name=declarative_stream.name, + schema_loader=declarative_stream._schema_loader, # type: ignore # We are accessing the private property but the public one is optional and we will remove this code soonish + retriever=retriever, + message_repository=self.message_repository, ), stream_slicer=declarative_stream.retriever.stream_slicer, ) @@ -309,10 +307,10 @@ def _group_streams( ) partition_generator = StreamSlicerPartitionGenerator( partition_factory=DeclarativePartitionFactory( - declarative_stream.name, - declarative_stream.get_json_schema(), - retriever, - self.message_repository, + stream_name=declarative_stream.name, + schema_loader=declarative_stream._schema_loader, # type: ignore # We are accessing the private property but the public one is optional and we will remove this code soonish + retriever=retriever, + message_repository=self.message_repository, ), stream_slicer=cursor, ) @@ -339,10 +337,10 @@ def _group_streams( ) and hasattr(declarative_stream.retriever, "stream_slicer"): partition_generator = StreamSlicerPartitionGenerator( DeclarativePartitionFactory( - declarative_stream.name, - declarative_stream.get_json_schema(), - declarative_stream.retriever, - self.message_repository, + stream_name=declarative_stream.name, + schema_loader=declarative_stream._schema_loader, # type: ignore # We are accessing the private property but the public one is optional and we will remove this code soonish + retriever=declarative_stream.retriever, + message_repository=self.message_repository, ), declarative_stream.retriever.stream_slicer, ) @@ -399,10 +397,10 @@ def _group_streams( partition_generator = StreamSlicerPartitionGenerator( DeclarativePartitionFactory( - declarative_stream.name, - declarative_stream.get_json_schema(), - retriever, - self.message_repository, + stream_name=declarative_stream.name, + schema_loader=declarative_stream._schema_loader, # type: ignore # We are accessing the private property but the public one is optional and we will remove this code soonish + retriever=retriever, + message_repository=self.message_repository, ), perpartition_cursor, ) diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py index e962f3813..b1736f371 100644 --- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py +++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py @@ -8,7 +8,7 @@ from copy import deepcopy from importlib import metadata from types import ModuleType -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set +from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Union import orjson import yaml @@ -66,6 +66,7 @@ from airbyte_cdk.sources.declarative.resolvers import COMPONENTS_RESOLVER_TYPE_MAPPING from airbyte_cdk.sources.declarative.spec.spec import Spec from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.core import Stream from airbyte_cdk.sources.types import Config, ConnectionDefinition from airbyte_cdk.sources.utils.slice_logger import ( @@ -297,7 +298,12 @@ def connection_checker(self) -> ConnectionChecker: f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}" ) - def streams(self, config: Mapping[str, Any]) -> List[Stream]: + def streams(self, config: Mapping[str, Any]) -> List[Union[Stream, AbstractStream]]: # type: ignore # we are migrating away from the AbstractSource and are expecting that this will only be called by ConcurrentDeclarativeSource or the Connector Builder + """ + As a migration step, this method will return both legacy stream (Stream) and concurrent stream (AbstractStream). + Once the migration is done, we can probably have this method throw "not implemented" as we figure out how to + fully decouple this from the AbstractSource. + """ if self._spec_component: self._spec_component.validate_config(config) diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index e1d07fdc2..1ec8136a4 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -7,6 +7,7 @@ import datetime import importlib import inspect +import logging import re from functools import partial from typing import ( @@ -543,6 +544,10 @@ StreamSlicer, StreamSlicerTestReadDecorator, ) +from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( + DeclarativePartitionFactory, + StreamSlicerPartitionGenerator, +) from airbyte_cdk.sources.declarative.transformations import ( AddFields, RecordTransformation, @@ -594,6 +599,7 @@ Rate, UnlimitedCallRatePolicy, ) +from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.clamping import ( ClampingEndProvider, ClampingStrategy, @@ -603,7 +609,14 @@ WeekClampingStrategy, Weekday, ) -from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, Cursor, CursorField +from airbyte_cdk.sources.streams.concurrent.cursor import ( + ConcurrentCursor, + Cursor, + CursorField, + FinalStateCursor, +) +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream +from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, DateTimeStreamStateConverter, @@ -633,7 +646,6 @@ def __init__( emit_connector_builder_messages: bool = False, disable_retries: bool = False, disable_cache: bool = False, - disable_resumable_full_refresh: bool = False, message_repository: Optional[MessageRepository] = None, connector_state_manager: Optional[ConnectorStateManager] = None, max_concurrent_async_job_count: Optional[int] = None, @@ -644,7 +656,6 @@ def __init__( self._emit_connector_builder_messages = emit_connector_builder_messages self._disable_retries = disable_retries self._disable_cache = disable_cache - self._disable_resumable_full_refresh = disable_resumable_full_refresh self._message_repository = message_repository or InMemoryMessageRepository( self._evaluate_log_level(emit_connector_builder_messages) ) @@ -1920,8 +1931,8 @@ def create_datetime_based_cursor( ) def create_declarative_stream( - self, model: DeclarativeStreamModel, config: Config, **kwargs: Any - ) -> DeclarativeStream: + self, model: DeclarativeStreamModel, config: Config, is_parent: bool = False, **kwargs: Any + ) -> Union[DeclarativeStream, AbstractStream]: # When constructing a declarative stream, we assemble the incremental_sync component and retriever's partition_router field # components if they exist into a single CartesianProductStreamSlicer. This is then passed back as an argument when constructing the # Retriever. This is done in the declarative stream not the retriever to support custom retrievers. The custom create methods in @@ -2027,15 +2038,6 @@ def create_declarative_stream( file_uploader=file_uploader, incremental_sync=model.incremental_sync, ) - cursor_field = model.incremental_sync.cursor_field if model.incremental_sync else None - - if model.state_migrations: - state_transformations = [ - self._create_component_from_model(state_migration, config, declarative_stream=model) - for state_migration in model.state_migrations - ] - else: - state_transformations = [] schema_loader: Union[ CompositeSchemaLoader, @@ -2063,6 +2065,56 @@ def create_declarative_stream( options["name"] = model.name schema_loader = DefaultSchemaLoader(config=config, parameters=options) + if ( + isinstance(combined_slicers, PartitionRouter) + and not self._emit_connector_builder_messages + and not is_parent + ): + # We are starting to migrate streams to instantiate directly the DefaultStream instead of instantiating the + # DeclarativeStream and assembling the DefaultStream from that. The plan is the following: + # * Streams without partition router nor cursors and streams with only partition router. This is the `isinstance(combined_slicers, PartitionRouter)` condition as the first kind with have a SinglePartitionRouter + # * Streams without partition router but with cursor + # * Streams with both partition router and cursor + # We specifically exclude parent streams here because SubstreamPartitionRouter has not been updated yet + # We specifically exclude Connector Builder stuff for now as Brian is working on this anyway + stream_name = model.name or "" + partition_generator = StreamSlicerPartitionGenerator( + DeclarativePartitionFactory( + stream_name, + schema_loader, + retriever, + self._message_repository, + ), + stream_slicer=cast( + StreamSlicer, + StreamSlicerTestReadDecorator( + wrapped_slicer=combined_slicers, + maximum_number_of_slices=self._limit_slices_fetched or 5, + ), + ), + ) + return DefaultStream( + partition_generator=partition_generator, + name=stream_name, + json_schema=schema_loader.get_json_schema, + primary_key=get_primary_key_from_stream(primary_key), + cursor_field=None, + # FIXME we should have the cursor field has part of the interface of cursor + logger=logging.getLogger(f"airbyte.{stream_name}"), + # FIXME this is a breaking change compared to the old implementation, + cursor=FinalStateCursor(stream_name, None, self._message_repository), + supports_file_transfer=hasattr(model, "file_uploader") + and bool(model.file_uploader), + ) + + cursor_field = model.incremental_sync.cursor_field if model.incremental_sync else None + if model.state_migrations: + state_transformations = [ + self._create_component_from_model(state_migration, config, declarative_stream=model) + for state_migration in model.state_migrations + ] + else: + state_transformations = [] return DeclarativeStream( name=model.name or "", primary_key=primary_key, @@ -2083,7 +2135,7 @@ def _build_stream_slicer_from_partition_router( ], config: Config, stream_name: Optional[str] = None, - ) -> Optional[PartitionRouter]: + ) -> PartitionRouter: if ( hasattr(model, "partition_router") and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel) @@ -2104,7 +2156,7 @@ def _build_stream_slicer_from_partition_router( return self._create_component_from_model( # type: ignore[no-any-return] # Will be created PartitionRouter as stream_slicer_model is model.partition_router model=stream_slicer_model, config=config, stream_name=stream_name or "" ) - return None + return SinglePartitionRouter(parameters={}) def _build_incremental_cursor( self, @@ -2121,7 +2173,9 @@ def _build_incremental_cursor( else [] ) - if model.incremental_sync and stream_slicer: + if model.incremental_sync and ( + stream_slicer and not isinstance(stream_slicer, SinglePartitionRouter) + ): if model.retriever.type == "AsyncRetriever": stream_name = model.name or "" stream_namespace = None @@ -2194,7 +2248,11 @@ def _build_concurrent_cursor( else: state_transformations = [] - if model.incremental_sync and stream_slicer: + if ( + model.incremental_sync + and stream_slicer + and not isinstance(stream_slicer, SinglePartitionRouter) + ): return self.create_concurrent_cursor_from_perpartition_cursor( # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing state_manager=self._connector_state_manager, model_type=DatetimeBasedCursorModel, @@ -2233,28 +2291,6 @@ def _build_concurrent_cursor( ) return None - def _build_resumable_cursor( - self, - model: Union[ - AsyncRetrieverModel, - CustomRetrieverModel, - SimpleRetrieverModel, - ], - stream_slicer: Optional[PartitionRouter], - ) -> Optional[StreamSlicer]: - if hasattr(model, "paginator") and model.paginator and not stream_slicer: - # For the regular Full-Refresh streams, we use the high level `ResumableFullRefreshCursor` - return ResumableFullRefreshCursor(parameters={}) - elif stream_slicer: - # For the Full-Refresh sub-streams, we use the nested `ChildPartitionResumableFullRefreshCursor` - return PerPartitionCursor( - cursor_factory=CursorFactory( - create_function=partial(ChildPartitionResumableFullRefreshCursor, {}) - ), - partition_router=stream_slicer, - ) - return None - def _merge_stream_slicers( self, model: DeclarativeStreamModel, config: Config ) -> Optional[StreamSlicer]: @@ -2291,11 +2327,7 @@ def _merge_stream_slicers( if model.incremental_sync: return self._build_incremental_cursor(model, stream_slicer, config) - return ( - stream_slicer - if self._disable_resumable_full_refresh - else self._build_resumable_cursor(retriever_model, stream_slicer) - ) + return stream_slicer def create_default_error_handler( self, model: DefaultErrorHandlerModel, config: Config, **kwargs: Any @@ -2577,9 +2609,6 @@ def create_schema_type_identifier( def create_dynamic_schema_loader( self, model: DynamicSchemaLoaderModel, config: Config, **kwargs: Any ) -> DynamicSchemaLoader: - stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) - combined_slicers = self._build_resumable_cursor(model.retriever, stream_slicer) - schema_transformations = [] if model.schema_transformations: for transformation_model in model.schema_transformations: @@ -2592,7 +2621,7 @@ def create_dynamic_schema_loader( config=config, name=name, primary_key=None, - stream_slicer=combined_slicers, + stream_slicer=self._build_stream_slicer_from_partition_router(model.retriever, config), transformations=[], use_cache=True, log_formatter=( @@ -2945,7 +2974,10 @@ def create_parent_stream_config( self, model: ParentStreamConfigModel, config: Config, **kwargs: Any ) -> ParentStreamConfig: declarative_stream = self._create_component_from_model( - model.stream, config=config, **kwargs + model.stream, + config=config, + is_parent=True, + **kwargs, ) request_option = ( self._create_component_from_model(model.request_option, config=config) @@ -3855,15 +3887,12 @@ def create_components_mapping_definition( def create_http_components_resolver( self, model: HttpComponentsResolverModel, config: Config, stream_name: Optional[str] = None ) -> Any: - stream_slicer = self._build_stream_slicer_from_partition_router(model.retriever, config) - combined_slicers = self._build_resumable_cursor(model.retriever, stream_slicer) - retriever = self._create_component_from_model( model=model.retriever, config=config, name=f"{stream_name if stream_name else '__http_components_resolver'}", primary_key=None, - stream_slicer=stream_slicer if stream_slicer else combined_slicers, + stream_slicer=self._build_stream_slicer_from_partition_router(model.retriever, config), transformations=[], ) diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py index 94ee03a56..a7ce26143 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py @@ -3,6 +3,7 @@ from typing import Any, Iterable, Mapping, Optional from airbyte_cdk.sources.declarative.retrievers import Retriever +from airbyte_cdk.sources.declarative.schema import SchemaLoader from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator @@ -11,11 +12,23 @@ from airbyte_cdk.utils.slice_hasher import SliceHasher +class SchemaLoaderCachingDecorator(SchemaLoader): + def __init__(self, schema_loader: SchemaLoader): + self._decorated = schema_loader + self._loaded_schema: Optional[Mapping[str, Any]] = None + + def get_json_schema(self) -> Mapping[str, Any]: + if self._loaded_schema is None: + self._loaded_schema = self._decorated.get_json_schema() + + return self._loaded_schema # type: ignore # at that point, we assume the schema will be populated + + class DeclarativePartitionFactory: def __init__( self, stream_name: str, - json_schema: Mapping[str, Any], + schema_loader: SchemaLoader, retriever: Retriever, message_repository: MessageRepository, ) -> None: @@ -25,17 +38,17 @@ def __init__( In order to avoid these problems, we will create one retriever per thread which should make the processing thread-safe. """ self._stream_name = stream_name - self._json_schema = json_schema + self._schema_loader = SchemaLoaderCachingDecorator(schema_loader) self._retriever = retriever self._message_repository = message_repository def create(self, stream_slice: StreamSlice) -> Partition: return DeclarativePartition( - self._stream_name, - self._json_schema, - self._retriever, - self._message_repository, - stream_slice, + stream_name=self._stream_name, + schema_loader=self._schema_loader, + retriever=self._retriever, + message_repository=self._message_repository, + stream_slice=stream_slice, ) @@ -43,20 +56,22 @@ class DeclarativePartition(Partition): def __init__( self, stream_name: str, - json_schema: Mapping[str, Any], + schema_loader: SchemaLoader, retriever: Retriever, message_repository: MessageRepository, stream_slice: StreamSlice, ): self._stream_name = stream_name - self._json_schema = json_schema + self._schema_loader = schema_loader self._retriever = retriever self._message_repository = message_repository self._stream_slice = stream_slice self._hash = SliceHasher.hash(self._stream_name, self._stream_slice) def read(self) -> Iterable[Record]: - for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice): + for stream_data in self._retriever.read_records( + self._schema_loader.get_json_schema(), self._stream_slice + ): if isinstance(stream_data, Mapping): record = ( stream_data diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index 949f0545b..c1dea49de 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -6,7 +6,7 @@ import json import logging from functools import lru_cache -from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union from typing_extensions import deprecated @@ -196,6 +196,7 @@ def cursor_field(self) -> Union[str, List[str]]: def cursor(self) -> Optional[Cursor]: # type: ignore[override] # StreamFaced expects to use only airbyte_cdk.sources.streams.concurrent.cursor.Cursor return self._cursor + # FIXME the lru_cache seems to be mostly there because of typing issue @lru_cache(maxsize=None) def get_json_schema(self) -> Mapping[str, Any]: return self._abstract_stream.get_json_schema() diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index 10f04e6ba..ca227fd50 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -2,9 +2,8 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from functools import lru_cache from logging import Logger -from typing import Any, Iterable, List, Mapping, Optional +from typing import Any, Callable, Iterable, List, Mapping, Optional, Union from airbyte_cdk.models import AirbyteStream, SyncMode from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -20,7 +19,7 @@ def __init__( self, partition_generator: PartitionGenerator, name: str, - json_schema: Mapping[str, Any], + json_schema: Union[Mapping[str, Any], Callable[[], Mapping[str, Any]]], primary_key: List[str], cursor_field: Optional[str], logger: Logger, @@ -53,14 +52,13 @@ def namespace(self) -> Optional[str]: def cursor_field(self) -> Optional[str]: return self._cursor_field - @lru_cache(maxsize=None) def get_json_schema(self) -> Mapping[str, Any]: - return self._json_schema + return self._json_schema() if callable(self._json_schema) else self._json_schema def as_airbyte_stream(self) -> AirbyteStream: stream = AirbyteStream( name=self.name, - json_schema=dict(self._json_schema), + json_schema=dict(self.get_json_schema()), supported_sync_modes=[SyncMode.full_refresh], is_resumable=False, is_file_based=self._supports_file_transfer, diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 2587fb95a..d6f1bf2d6 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -7,7 +7,7 @@ import json import logging import os -from typing import List, Literal +from typing import List, Literal, Union from unittest import mock from unittest.mock import MagicMock, patch @@ -56,10 +56,14 @@ Type, ) from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, +) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicerTestReadDecorator +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets, update_secrets from unit_tests.connector_builder.utils import create_configured_catalog @@ -440,6 +444,14 @@ } +def get_retriever(stream: Union[DeclarativeStream, DefaultStream]): + return ( + stream.retriever + if isinstance(stream, DeclarativeStream) + else stream._stream_partition_generator._partition_factory._retriever + ) + + @pytest.fixture def valid_resolve_manifest_config_file(tmp_path): config_file = tmp_path / "config.json" @@ -780,7 +792,13 @@ def test_config_update() -> None: "client_secret": "a client secret", "refresh_token": "a refresh token", } - source = ManifestDeclarativeSource(source_config=manifest) + source = ConcurrentDeclarativeSource( + catalog=None, + config=config, + state=None, + source_config=manifest, + emit_connector_builder_messages=True, + ) refresh_request_response = { "access_token": "an updated access token", @@ -1117,8 +1135,9 @@ def test_read_source(mock_http_stream): streams = source.streams(config) for s in streams: - assert isinstance(s.retriever, SimpleRetriever) - assert isinstance(s.retriever.stream_slicer, StreamSlicerTestReadDecorator) + retriever = get_retriever(s) + assert isinstance(retriever, SimpleRetriever) + assert isinstance(retriever.stream_slicer, StreamSlicerTestReadDecorator) @patch.object( @@ -1164,8 +1183,9 @@ def test_read_source_single_page_single_slice(mock_http_stream): streams = source.streams(config) for s in streams: - assert isinstance(s.retriever, SimpleRetriever) - assert isinstance(s.retriever.stream_slicer, StreamSlicerTestReadDecorator) + retriever = get_retriever(s) + assert isinstance(retriever, SimpleRetriever) + assert isinstance(retriever.stream_slicer, StreamSlicerTestReadDecorator) @pytest.mark.parametrize( diff --git a/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py b/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py index 6901c6382..2960c5802 100644 --- a/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py +++ b/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py @@ -93,9 +93,8 @@ def get_body(): requests_mock.get("https://for-all-mankind.nasa.com/api/v1/users/users3", body=get_body()) requests_mock.get("https://for-all-mankind.nasa.com/api/v1/users/users4", body=get_body()) - stream_slices = list(stream.stream_slices(sync_mode=SyncMode.full_refresh)) - for stream_slice in stream_slices: - for _ in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice): + for partition in stream.generate_partitions(): + for _ in partition.read(): counter += 1 - assert counter == lines_in_response * len(stream_slices) + assert counter == lines_in_response * 4 # 4 partitions diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py index 13d1194dd..ba26f7c91 100644 --- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py +++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py @@ -20,6 +20,7 @@ ConcurrentDeclarativeSource, ) from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor +from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( DeclarativePartition, ) @@ -31,6 +32,8 @@ from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read +_EMPTY_SCHEMA_LOADER = InlineSchemaLoader(schema={}, parameters={}) + SUBSTREAM_MANIFEST: MutableMapping[str, Any] = { "version": "0.51.42", "type": "DeclarativeSource", @@ -3614,7 +3617,13 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update slices = list(cursor.stream_slices()) # Call once for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=slice, + ) ) assert cursor.state == { @@ -3692,7 +3701,13 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update(): # Close all partitions except from the first one for slice in slices[1:]: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3780,7 +3795,13 @@ def test_given_unfinished_last_parent_partition_with_partial_parent_state_update # Close all partitions except from the first one for slice in slices[:-1]: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3863,7 +3884,13 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi slices = list(cursor.stream_slices()) for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3930,7 +3957,13 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob slices = list(cursor.stream_slices()) for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -4007,7 +4040,15 @@ def test_semaphore_cleanup(): # Close partitions to acquire semaphores (value back to 0) for s in generated_slices: - cursor.close_partition(DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), s)) + cursor.close_partition( + DeclarativePartition( + stream_name="test_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=s, + ) + ) # Check state after closing partitions assert len(cursor._partitions_done_generating_stream_slices) == 0 @@ -4119,15 +4160,35 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted(): first_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=first_1, + ) ) two = next(slice_gen) - cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two)) + cursor.close_partition( + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=two, + ) + ) second_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=second_1, + ) ) assert cursor._IS_PARTITION_DUPLICATION_LOGGED is False # No duplicate detected @@ -4181,16 +4242,36 @@ def test_duplicate_partition_after_closing_partition_cursor_exists(): first_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=first_1, + ) ) two = next(slice_gen) - cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two)) + cursor.close_partition( + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=two, + ) + ) # Second “1” should appear because the semaphore was cleaned up second_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=second_1, + ) ) with pytest.raises(StopIteration): @@ -4241,11 +4322,23 @@ def test_duplicate_partition_while_processing(): # Close “2” first cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[1]) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=generated[1], + ) ) # Now close the initial “1” cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[0]) + DeclarativePartition( + stream_name="dup_stream", + schema_loader=_EMPTY_SCHEMA_LOADER, + retriever=MagicMock(), + message_repository=MagicMock(), + stream_slice=generated[0], + ) ) assert cursor._IS_PARTITION_DUPLICATION_LOGGED is True # warning emitted diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index f7df8d6d4..aa8d0d781 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -157,6 +157,9 @@ from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader from airbyte_cdk.sources.declarative.spec import Spec from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicerTestReadDecorator +from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( + SchemaLoaderCachingDecorator, +) from airbyte_cdk.sources.declarative.transformations import AddFields, RemoveFields from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource @@ -168,6 +171,7 @@ WeekClampingStrategy, ) from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( CustomFormatConcurrentStreamStateConverter, ) @@ -1060,152 +1064,6 @@ def test_stream_with_incremental_and_async_retriever_with_partition_router(use_l assert stream_slices == expected_stream_slices -def test_resumable_full_refresh_stream(): - content = """ -decoder: - type: JsonDecoder -extractor: - type: DpathExtractor -selector: - type: RecordSelector - record_filter: - type: RecordFilter - condition: "{{ record['id'] > stream_state['id'] }}" -metadata_paginator: - type: DefaultPaginator - page_size_option: - type: RequestOption - inject_into: body_json - field_path: ["variables", "page_size"] - page_token_option: - type: RequestPath - pagination_strategy: - type: "CursorPagination" - cursor_value: "{{ response._metadata.next }}" - page_size: 10 -requester: - type: HttpRequester - url_base: "https://api.sendgrid.com/v3/" - http_method: "GET" - authenticator: - type: BearerAuthenticator - api_token: "{{ config['apikey'] }}" - request_parameters: - unit: "day" -retriever: - paginator: - type: NoPagination - decoder: - $ref: "#/decoder" -partial_stream: - type: DeclarativeStream - schema_loader: - type: JsonFileSchemaLoader - file_path: "./source_sendgrid/schemas/{{ parameters.name }}.json" -list_stream: - $ref: "#/partial_stream" - $parameters: - name: "lists" - extractor: - $ref: "#/extractor" - field_path: ["{{ parameters['name'] }}"] - name: "lists" - primary_key: "id" - retriever: - $ref: "#/retriever" - requester: - $ref: "#/requester" - path: "{{ next_page_token['next_page_url'] }}" - paginator: - $ref: "#/metadata_paginator" - record_selector: - $ref: "#/selector" - transformations: - - type: AddFields - fields: - - path: ["extra"] - value: "{{ response.to_add }}" -check: - type: CheckStream - stream_names: ["list_stream"] -spec: - type: Spec - documentation_url: https://airbyte.com/#yaml-from-manifest - connection_specification: - title: Test Spec - type: object - required: - - api_key - additionalProperties: false - properties: - api_key: - type: string - airbyte_secret: true - title: API Key - description: Test API Key - order: 0 - advanced_auth: - auth_flow_type: "oauth2.0" - """ - parsed_manifest = YamlDeclarativeSource._parse(content) - resolved_manifest = resolver.preprocess_manifest(parsed_manifest) - resolved_manifest["type"] = "DeclarativeSource" - manifest = transformer.propagate_types_and_parameters("", resolved_manifest, {}) - - stream_manifest = manifest["list_stream"] - assert stream_manifest["type"] == "DeclarativeStream" - stream = factory.create_component( - model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config - ) - - assert isinstance(stream, DeclarativeStream) - assert stream.primary_key == "id" - assert stream.name == "lists" - assert stream._stream_cursor_field.string == "" - - assert isinstance(stream.retriever, SimpleRetriever) - assert stream.retriever.primary_key == stream.primary_key - assert stream.retriever.name == stream.name - - assert isinstance(stream.retriever.record_selector, RecordSelector) - - assert isinstance(stream.retriever.stream_slicer, ResumableFullRefreshCursor) - assert isinstance(stream.retriever.cursor, ResumableFullRefreshCursor) - - assert isinstance(stream.retriever.paginator, DefaultPaginator) - assert isinstance(stream.retriever.paginator.decoder, PaginationDecoderDecorator) - for string in stream.retriever.paginator.page_size_option.field_path: - assert isinstance(string, InterpolatedString) - assert len(stream.retriever.paginator.page_size_option.field_path) == 2 - assert stream.retriever.paginator.page_size_option.inject_into == RequestOptionType.body_json - assert isinstance(stream.retriever.paginator.page_token_option, RequestPath) - assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com/v3/" - assert stream.retriever.paginator.url_base.default == "https://api.sendgrid.com/v3/" - - assert isinstance(stream.retriever.paginator.pagination_strategy, CursorPaginationStrategy) - assert isinstance( - stream.retriever.paginator.pagination_strategy.decoder, PaginationDecoderDecorator - ) - assert ( - stream.retriever.paginator.pagination_strategy._cursor_value.string - == "{{ response._metadata.next }}" - ) - assert ( - stream.retriever.paginator.pagination_strategy._cursor_value.default - == "{{ response._metadata.next }}" - ) - assert stream.retriever.paginator.pagination_strategy.page_size == 10 - - checker = factory.create_component( - model_type=CheckStreamModel, component_definition=manifest["check"], config=input_config - ) - - assert isinstance(checker, CheckStream) - streams_to_check = checker.stream_names - assert len(streams_to_check) == 1 - assert list(streams_to_check)[0] == "list_stream" - - def test_incremental_data_feed(): content = """ selector: @@ -1908,38 +1766,33 @@ def test_config_with_defaults(): model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config ) - assert isinstance(stream, DeclarativeStream) - assert stream.primary_key == "id" + assert isinstance(stream, DefaultStream) assert stream.name == "lists" - assert isinstance(stream.retriever, SimpleRetriever) - assert stream.retriever.name == stream.name - assert stream.retriever.primary_key == stream.primary_key + retriever = stream._stream_partition_generator._partition_factory._retriever + assert isinstance(retriever, SimpleRetriever) + assert retriever.name == stream.name + assert retriever.primary_key == "id" - assert isinstance(stream.schema_loader, JsonFileSchemaLoader) - assert ( - stream.schema_loader.file_path.string - == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" - ) - assert ( - stream.schema_loader.file_path.default - == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" - ) + schema_loader = get_schema_loader(stream) + assert isinstance(schema_loader, JsonFileSchemaLoader) + assert schema_loader.file_path.string == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" + assert schema_loader.file_path.default == "./source_sendgrid/schemas/{{ parameters.name }}.yaml" - assert isinstance(stream.retriever.requester, HttpRequester) - assert stream.retriever.requester.http_method == HttpMethod.GET + assert isinstance(retriever.requester, HttpRequester) + assert retriever.requester.http_method == HttpMethod.GET - assert isinstance(stream.retriever.requester.authenticator, BearerAuthenticator) - assert stream.retriever.requester.authenticator.token_provider.get_token() == "verysecrettoken" + assert isinstance(retriever.requester.authenticator, BearerAuthenticator) + assert retriever.requester.authenticator.token_provider.get_token() == "verysecrettoken" - assert isinstance(stream.retriever.record_selector, RecordSelector) - assert isinstance(stream.retriever.record_selector.extractor, DpathExtractor) - assert [ - fp.eval(input_config) for fp in stream.retriever.record_selector.extractor._field_path - ] == ["result"] + assert isinstance(retriever.record_selector, RecordSelector) + assert isinstance(retriever.record_selector.extractor, DpathExtractor) + assert [fp.eval(input_config) for fp in retriever.record_selector.extractor._field_path] == [ + "result" + ] - assert isinstance(stream.retriever.paginator, DefaultPaginator) - assert stream.retriever.paginator.url_base.string == "https://api.sendgrid.com" - assert stream.retriever.paginator.pagination_strategy.get_page_size() == 10 + assert isinstance(retriever.paginator, DefaultPaginator) + assert retriever.paginator.url_base.string == "https://api.sendgrid.com" + assert retriever.paginator.pagination_strategy.get_page_size() == 10 def test_create_default_paginator(): @@ -2335,8 +2188,8 @@ def test_no_transformations(self): config=input_config, ) - assert isinstance(stream, DeclarativeStream) - assert [] == stream.retriever.record_selector.transformations + assert isinstance(stream, DefaultStream) + assert [] == get_retriever(stream).record_selector.transformations def test_remove_fields(self): content = f""" @@ -2363,11 +2216,11 @@ def test_remove_fields(self): config=input_config, ) - assert isinstance(stream, DeclarativeStream) + assert isinstance(stream, DefaultStream) expected = [ RemoveFields(field_pointers=[["path", "to", "field1"], ["path2"]], parameters={}) ] - assert stream.retriever.record_selector.transformations == expected + assert get_retriever(stream).record_selector.transformations == expected def test_add_fields_no_value_type(self): content = f""" @@ -2526,8 +2379,8 @@ def _test_add_fields(self, content, expected): config=input_config, ) - assert isinstance(stream, DeclarativeStream) - assert stream.retriever.record_selector.transformations == expected + assert isinstance(stream, DefaultStream) + assert get_retriever(stream).record_selector.transformations == expected def test_default_schema_loader(self): component_definition = { @@ -2566,7 +2419,7 @@ def test_default_schema_loader(self): component_definition=propagated_source_config, config=input_config, ) - schema_loader = stream.schema_loader + schema_loader = get_schema_loader(stream) assert ( schema_loader.default_loader._get_json_filepath().split("/")[-1] == f"{stream.name}.json" @@ -2574,7 +2427,7 @@ def test_default_schema_loader(self): @pytest.mark.parametrize( - "incremental, partition_router, expected_type", + "incremental, partition_router, expected_router_type, expected_stream_type", [ pytest.param( { @@ -2588,6 +2441,7 @@ def test_default_schema_loader(self): }, None, DatetimeBasedCursor, + DeclarativeStream, id="test_create_simple_retriever_with_incremental", ), pytest.param( @@ -2597,7 +2451,8 @@ def test_default_schema_loader(self): "values": "{{config['repos']}}", "cursor_field": "a_key", }, - PerPartitionCursor, + ListPartitionRouter, + DefaultStream, id="test_create_simple_retriever_with_partition_router", ), pytest.param( @@ -2616,6 +2471,7 @@ def test_default_schema_loader(self): "cursor_field": "a_key", }, PerPartitionWithGlobalCursor, + DeclarativeStream, id="test_create_simple_retriever_with_incremental_and_partition_router", ), pytest.param( @@ -2641,17 +2497,21 @@ def test_default_schema_loader(self): }, ], PerPartitionWithGlobalCursor, + DeclarativeStream, id="test_create_simple_retriever_with_partition_routers_multiple_components", ), pytest.param( None, None, SinglePartitionRouter, + DefaultStream, id="test_create_simple_retriever_with_no_incremental_or_partition_router", ), ], ) -def test_merge_incremental_and_partition_router(incremental, partition_router, expected_type): +def test_merge_incremental_and_partition_router( + incremental, partition_router, expected_router_type, expected_stream_type +): stream_model = { "type": "DeclarativeStream", "retriever": { @@ -2682,22 +2542,25 @@ def test_merge_incremental_and_partition_router(incremental, partition_router, e model_type=DeclarativeStreamModel, component_definition=stream_model, config=input_config ) - assert isinstance(stream, DeclarativeStream) - assert isinstance(stream.retriever, SimpleRetriever) - assert isinstance(stream.retriever.stream_slicer, expected_type) + assert isinstance(stream, expected_stream_type) + retriever = get_retriever(stream) + assert isinstance(retriever, SimpleRetriever) + stream_slicer = ( + retriever.stream_slicer + if expected_stream_type == DeclarativeStream + else stream._stream_partition_generator._stream_slicer + ) + assert isinstance(stream_slicer, expected_router_type) if incremental and partition_router: - assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) + assert isinstance(retriever.stream_slicer, PerPartitionWithGlobalCursor) if isinstance(partition_router, list) and len(partition_router) > 1: assert isinstance( - stream.retriever.stream_slicer._partition_router, CartesianProductStreamSlicer + retriever.stream_slicer._partition_router, CartesianProductStreamSlicer ) - assert len(stream.retriever.stream_slicer._partition_router.stream_slicers) == len( + assert len(retriever.stream_slicer._partition_router.stream_slicers) == len( partition_router ) - elif partition_router and isinstance(partition_router, list) and len(partition_router) > 1: - assert isinstance(stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) - assert len(stream.retriever.stream_slicer.stream_slicerS) == len(partition_router) def test_simple_retriever_emit_log_messages(): @@ -2865,8 +2728,10 @@ def test_create_custom_retriever(): model_type=DeclarativeStreamModel, component_definition=stream_model, config=input_config ) - assert isinstance(stream, DeclarativeStream) - assert isinstance(stream.retriever, MyCustomRetriever) + assert isinstance(stream, DefaultStream) + assert isinstance( + stream._stream_partition_generator._partition_factory._retriever, MyCustomRetriever + ) @freezegun.freeze_time("2021-01-01 00:00:00") @@ -4797,14 +4662,30 @@ def test_create_stream_with_multiple_schema_loaders(): "", resolved_manifest["stream_A"], {} ) - declarative_stream = factory.create_component( + stream = factory.create_component( model_type=DeclarativeStreamModel, component_definition=partition_router_manifest, config=input_config, ) - schema_loader = declarative_stream.schema_loader + schema_loader = get_schema_loader(stream) assert isinstance(schema_loader, CompositeSchemaLoader) assert len(schema_loader.schema_loaders) == 2 assert isinstance(schema_loader.schema_loaders[0], InlineSchemaLoader) assert isinstance(schema_loader.schema_loaders[1], InlineSchemaLoader) + + +def get_schema_loader(stream: DefaultStream): + assert isinstance( + stream._stream_partition_generator._partition_factory._schema_loader, + SchemaLoaderCachingDecorator, + ) + return stream._stream_partition_generator._partition_factory._schema_loader._decorated + + +def get_retriever(stream: Union[DeclarativeStream, DefaultStream]): + return ( + stream.retriever + if isinstance(stream, DeclarativeStream) + else stream._stream_partition_generator._partition_factory._retriever + ) diff --git a/unit_tests/sources/declarative/resolvers/test_config_components_resolver.py b/unit_tests/sources/declarative/resolvers/test_config_components_resolver.py index 2f2cbca5b..c9ca1ecd5 100644 --- a/unit_tests/sources/declarative/resolvers/test_config_components_resolver.py +++ b/unit_tests/sources/declarative/resolvers/test_config_components_resolver.py @@ -383,5 +383,6 @@ def test_component_mapping_conditions(manifest, config, expected_conditional_par for stream in source.streams(config): if stream.name in expected_conditional_params: assert ( - stream.retriever.requester._parameters == expected_conditional_params[stream.name] + stream._stream_partition_generator._partition_factory._retriever.requester._parameters + == expected_conditional_params[stream.name] ) diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index a1e390177..44f307a32 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -265,120 +265,6 @@ def test_simple_retriever_resumable_full_refresh_cursor_page_increment( assert retriever.state == {"__ab_full_refresh_sync_complete": True} -@pytest.mark.parametrize( - "initial_state, expected_reset_value, expected_next_page", - [ - pytest.param(None, None, 1, id="test_initial_sync_no_state"), - pytest.param( - { - "next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=tracy_stevens" - }, - "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=tracy_stevens", - "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", - id="test_reset_with_next_page_token", - ), - ], -) -def test_simple_retriever_resumable_full_refresh_cursor_reset_cursor_pagination( - initial_state, expected_reset_value, expected_next_page, requests_mock -): - expected_records = [ - Record(data={"name": "ed_baldwin"}, associated_slice=None, stream_name="users"), - Record(data={"name": "danielle_poole"}, associated_slice=None, stream_name="users"), - Record(data={"name": "tracy_stevens"}, associated_slice=None, stream_name="users"), - Record(data={"name": "deke_slayton"}, associated_slice=None, stream_name="users"), - Record(data={"name": "molly_cobb"}, associated_slice=None, stream_name="users"), - Record(data={"name": "gordo_stevens"}, associated_slice=None, stream_name="users"), - Record(data={"name": "margo_madison"}, associated_slice=None, stream_name="users"), - Record(data={"name": "ellen_waverly"}, associated_slice=None, stream_name="users"), - ] - - content = """ -name: users -type: DeclarativeStream -retriever: - type: SimpleRetriever - decoder: - type: JsonDecoder - paginator: - type: "DefaultPaginator" - page_token_option: - type: RequestPath - pagination_strategy: - type: "CursorPagination" - cursor_value: "{{ response.next_page }}" - requester: - path: /astronauts - type: HttpRequester - url_base: "https://for-all-mankind.nasa.com/api/v1" - http_method: GET - authenticator: - type: ApiKeyAuthenticator - api_token: "{{ config['api_key'] }}" - inject_into: - type: RequestOption - field_name: Api-Key - inject_into: header - request_headers: {} - request_body_json: {} - record_selector: - type: RecordSelector - extractor: - type: DpathExtractor - field_path: ["data"] - partition_router: [] -primary_key: [] - """ - - factory = ModelToComponentFactory() - stream_manifest = YamlDeclarativeSource._parse(content) - stream = factory.create_component( - model_type=DeclarativeStreamModel, component_definition=stream_manifest, config={} - ) - response_body = { - "data": [r.data for r in expected_records[:5]], - "next_page": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", - } - requests_mock.get("https://for-all-mankind.nasa.com/api/v1/astronauts", json=response_body) - requests_mock.get( - "https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body - ) - response_body_2 = { - "data": [r.data for r in expected_records[5:]], - } - requests_mock.get( - "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens", - json=response_body_2, - ) - stream_slicer = ResumableFullRefreshCursor(parameters={}) - if initial_state: - stream_slicer.set_initial_state(initial_state) - stream.retriever.stream_slices = stream_slicer - stream.retriever.cursor = stream_slicer - stream_slice = list(stream_slicer.stream_slices())[0] - actual_records = [ - r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice) - ] - - assert len(actual_records) == 5 - assert actual_records == expected_records[:5] - assert stream.retriever.state == { - "next_page_token": "https://for-all-mankind.nasa.com/api/v1/astronauts?next_page=gordo_stevens" - } - requests_mock.get( - "https://for-all-mankind.nasa.com/astronauts?next_page=tracy_stevens", json=response_body - ) - requests_mock.get( - "https://for-all-mankind.nasa.com/astronauts?next_page=gordo_stevens", json=response_body_2 - ) - actual_records = [ - r for r in stream.retriever.read_records(records_schema={}, stream_slice=stream_slice) - ] - assert len(actual_records) == 3 - assert actual_records == expected_records[5:] - assert stream.retriever.state == {"__ab_full_refresh_sync_complete": True} - - def test_simple_retriever_resumable_full_refresh_cursor_reset_skip_completed_stream(): expected_records = [ Record(data={"id": "abc"}, associated_slice=None, stream_name="test_stream"), diff --git a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py index 3ced03a69..b09c708ad 100644 --- a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py +++ b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py @@ -6,6 +6,7 @@ from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.declarative.retrievers import Retriever +from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( DeclarativePartitionFactory, ) @@ -14,7 +15,7 @@ from airbyte_cdk.sources.types import StreamSlice _STREAM_NAME = "a_stream_name" -_JSON_SCHEMA = {"type": "object", "properties": {}} +_SCHEMA_LOADER = InlineSchemaLoader({"type": "object", "properties": {}}, {}) _A_STREAM_SLICE = StreamSlice( partition={"partition_key": "partition_value"}, cursor_slice={"cursor_key": "cursor_value"} ) @@ -34,7 +35,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) message_repository = Mock(spec=MessageRepository) partition_factory = DeclarativePartitionFactory( _STREAM_NAME, - _JSON_SCHEMA, + _SCHEMA_LOADER, retriever, message_repository, ) @@ -49,7 +50,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None: message_repository = Mock(spec=MessageRepository) partition_factory = DeclarativePartitionFactory( _STREAM_NAME, - _JSON_SCHEMA, + _SCHEMA_LOADER, retriever, message_repository, ) @@ -67,7 +68,7 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> N message_repository = Mock(spec=MessageRepository) partition_factory = DeclarativePartitionFactory( _STREAM_NAME, - _JSON_SCHEMA, + _SCHEMA_LOADER, retriever, message_repository, ) diff --git a/unit_tests/sources/declarative/test_manifest_declarative_source.py b/unit_tests/sources/declarative/test_manifest_declarative_source.py index 6753e8e4e..24258f193 100644 --- a/unit_tests/sources/declarative/test_manifest_declarative_source.py +++ b/unit_tests/sources/declarative/test_manifest_declarative_source.py @@ -28,12 +28,17 @@ SyncMode, Type, ) +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, +) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( ModelToComponentFactory, ) from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream +from unit_tests.sources.declarative.parsers.test_model_to_component_factory import get_retriever logger = logging.getLogger("airbyte") @@ -280,8 +285,8 @@ def test_valid_manifest(self): streams = source.streams({}) assert len(streams) == 2 - assert isinstance(streams[0], DeclarativeStream) - assert isinstance(streams[1], DeclarativeStream) + assert isinstance(streams[0], DefaultStream) + assert isinstance(streams[1], DefaultStream) assert ( source.resolved_manifest["description"] == "This is a sample source connector that is very valid." @@ -1289,13 +1294,13 @@ def test_conditional_streams_manifest(self, is_sandbox, expected_stream_count): actual_streams = source.streams(config=config) assert len(actual_streams) == expected_stream_count - assert isinstance(actual_streams[0], DeclarativeStream) + assert isinstance(actual_streams[0], DefaultStream) assert actual_streams[0].name == "students" if is_sandbox: - assert isinstance(actual_streams[1], DeclarativeStream) + assert isinstance(actual_streams[1], DefaultStream) assert actual_streams[1].name == "classrooms" - assert isinstance(actual_streams[2], DeclarativeStream) + assert isinstance(actual_streams[2], DefaultStream) assert actual_streams[2].name == "clubs" assert ( @@ -1818,8 +1823,8 @@ def _create_page(response_body): [ call({}, {}, None), call( - {"next_page_token": "next"}, - {"next_page_token": "next"}, + {}, + {}, {"next_page_token": "next"}, ), ], @@ -1907,16 +1912,9 @@ def _create_page(response_body): ), [{"ABC": 0, "partition": 0}, {"AED": 1, "partition": 0}, {"ABC": 2, "partition": 1}], [ - call({"states": []}, {"partition": "0"}, None), + call({}, {"partition": "0"}, None), call( - { - "states": [ - { - "partition": {"partition": "0"}, - "cursor": {"__ab_full_refresh_sync_complete": True}, - } - ] - }, + {}, {"partition": "1"}, None, ), @@ -2022,17 +2020,10 @@ def _create_page(response_body): {"ABC": 2, "partition": 1}, ], [ - call({"states": []}, {"partition": "0"}, None), - call({"states": []}, {"partition": "0"}, {"next_page_token": "next"}), + call({}, {"partition": "0"}, None), + call({}, {"partition": "0"}, {"next_page_token": "next"}), call( - { - "states": [ - { - "partition": {"partition": "0"}, - "cursor": {"__ab_full_refresh_sync_complete": True}, - } - ] - }, + {}, {"partition": "1"}, None, ), @@ -2193,30 +2184,26 @@ def test_only_parent_streams_use_cache(): # Main stream with caching (parent for substream `applications_interviews`) assert streams[0].name == "applications" - assert streams[0].retriever.requester.use_cache + assert get_retriever(streams[0]).requester.use_cache # Substream assert streams[1].name == "applications_interviews" - assert not streams[1].retriever.requester.use_cache + + stream_1_retriever = get_retriever(streams[1]) + assert not stream_1_retriever.requester.use_cache # Parent stream created for substream - assert ( - streams[1].retriever.stream_slicer._partition_router.parent_stream_configs[0].stream.name - == "applications" - ) - assert ( - streams[1] - .retriever.stream_slicer._partition_router.parent_stream_configs[0] - .stream.retriever.requester.use_cache - ) + assert stream_1_retriever.stream_slicer.parent_stream_configs[0].stream.name == "applications" + assert stream_1_retriever.stream_slicer.parent_stream_configs[ + 0 + ].stream.retriever.requester.use_cache # Main stream without caching assert streams[2].name == "jobs" - assert not streams[2].retriever.requester.use_cache + assert not get_retriever(streams[2]).requester.use_cache def _run_read(manifest: Mapping[str, Any], stream_name: str) -> List[AirbyteMessage]: - source = ManifestDeclarativeSource(source_config=manifest) catalog = ConfiguredAirbyteCatalog( streams=[ ConfiguredAirbyteStream( @@ -2228,7 +2215,10 @@ def _run_read(manifest: Mapping[str, Any], stream_name: str) -> List[AirbyteMess ) ] ) - return list(source.read(logger, {}, catalog, {})) + config = {} + state = {} + source = ConcurrentDeclarativeSource(catalog, config, state, manifest) + return list(source.read(logger, {}, catalog, state)) def test_declarative_component_schema_valid_ref_links(): diff --git a/unit_tests/sources/streams/concurrent/test_default_stream.py b/unit_tests/sources/streams/concurrent/test_default_stream.py index 12e2b34f4..fb3428afb 100644 --- a/unit_tests/sources/streams/concurrent/test_default_stream.py +++ b/unit_tests/sources/streams/concurrent/test_default_stream.py @@ -45,6 +45,27 @@ def test_get_json_schema(self): json_schema = self._stream.get_json_schema() assert json_schema == self._json_schema + def test_json_schema_is_callable(self): + expected = {"schema": "is callable"} + json_schema_callable = lambda: expected + stream = DefaultStream( + self._partition_generator, + self._name, + json_schema_callable, + self._primary_key, + self._cursor_field, + self._logger, + FinalStateCursor( + stream_name=self._name, + stream_namespace=None, + message_repository=self._message_repository, + ), + ) + + result = stream.get_json_schema() + + assert result == expected + def test_check_for_error_raises_an_exception_if_any_of_the_futures_are_not_done(self): futures = [Mock() for _ in range(3)] for f in futures: