diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index a7d2163a9..513546737 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -3,8 +3,8 @@ # -from dataclasses import asdict -from typing import Any, Dict, List, Mapping, Optional +from dataclasses import asdict, dataclass, field +from typing import Any, ClassVar, Dict, List, Mapping from airbyte_cdk.connector_builder.test_reader import TestReader from airbyte_cdk.models import ( @@ -15,32 +15,45 @@ Type, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( - ConcurrentDeclarativeSource, - TestLimits, -) from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, +) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.datetime_helpers import ab_datetime_now from airbyte_cdk.utils.traced_exception import AirbyteTracedException +DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5 +DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5 +DEFAULT_MAXIMUM_RECORDS = 100 +DEFAULT_MAXIMUM_STREAMS = 100 + MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice" MAX_SLICES_KEY = "max_slices" MAX_RECORDS_KEY = "max_records" MAX_STREAMS_KEY = "max_streams" +@dataclass +class TestLimits: + __test__: ClassVar[bool] = False # Tell Pytest this is not a Pytest class, despite its name + + max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS) + max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE) + max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES) + max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS) + + def get_limits(config: Mapping[str, Any]) -> TestLimits: command_config = config.get("__test_read_config", {}) - return TestLimits( - max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS), - max_pages_per_slice=command_config.get( - MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE - ), - max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES), - max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS), + max_pages_per_slice = ( + command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE ) + max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES + max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS + max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS + return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams) def should_migrate_manifest(config: Mapping[str, Any]) -> bool: @@ -62,30 +75,21 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool: return config.get("__should_normalize", False) -def create_source( - config: Mapping[str, Any], - limits: TestLimits, - catalog: Optional[ConfiguredAirbyteCatalog], - state: Optional[List[AirbyteStateMessage]], -) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]: +def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource: manifest = config["__injected_declarative_manifest"] - - # We enforce a concurrency level of 1 so that the stream is processed on a single thread - # to retain ordering for the grouping of the builder message responses. - if "concurrency_level" in manifest: - manifest["concurrency_level"]["default_concurrency"] = 1 - else: - manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1} - - return ConcurrentDeclarativeSource( - catalog=catalog, + return ManifestDeclarativeSource( config=config, - state=state, - source_config=manifest, emit_connector_builder_messages=True, + source_config=manifest, migrate_manifest=should_migrate_manifest(config), normalize_manifest=should_normalize_manifest(config), - limits=limits, + component_factory=ModelToComponentFactory( + emit_connector_builder_messages=True, + limit_pages_fetched_per_slice=limits.max_pages_per_slice, + limit_slices_fetched=limits.max_slices, + disable_retries=True, + disable_cache=True, + ), ) diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index 22be81c82..80cf4afa9 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -91,12 +91,12 @@ def handle_connector_builder_request( def handle_request(args: List[str]) -> str: command, config, catalog, state = get_config_and_catalog_from_args(args) limits = get_limits(config) - source = create_source(config=config, limits=limits, catalog=catalog, state=state) - return orjson.dumps( # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage + source = create_source(config, limits) + return orjson.dumps( AirbyteMessageSerializer.dump( handle_connector_builder_request(source, command, config, catalog, state, limits) ) - ).decode() + ).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage if __name__ == "__main__": diff --git a/airbyte_cdk/connector_builder/test_reader/helpers.py b/airbyte_cdk/connector_builder/test_reader/helpers.py index 3cc634ccb..9154610cc 100644 --- a/airbyte_cdk/connector_builder/test_reader/helpers.py +++ b/airbyte_cdk/connector_builder/test_reader/helpers.py @@ -5,7 +5,7 @@ import json from copy import deepcopy from json import JSONDecodeError -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional from airbyte_cdk.connector_builder.models import ( AuxiliaryRequest, @@ -17,8 +17,6 @@ from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, - AirbyteStateBlob, - AirbyteStateMessage, OrchestratorType, TraceType, ) @@ -468,7 +466,7 @@ def handle_current_slice( return StreamReadSlices( pages=current_slice_pages, slice_descriptor=current_slice_descriptor, - state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [], + state=[latest_state_message] if latest_state_message else [], auxiliary_requests=auxiliary_requests if auxiliary_requests else [], ) @@ -720,23 +718,3 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str: # type: ignore Determines the type of the auxiliary request based on the stream and HTTP properties. """ return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None)) - - -def convert_state_blob_to_mapping( - state_message: Union[AirbyteStateMessage, Dict[str, Any]], -) -> Dict[str, Any]: - """ - The AirbyteStreamState stores state as an AirbyteStateBlob which deceivingly is not - a dictionary, but rather a list of kwargs fields. This in turn causes it to not be - properly turned into a dictionary when translating this back into response output - by the connector_builder_handler using asdict() - """ - - if isinstance(state_message, AirbyteStateMessage) and state_message.stream: - state_value = state_message.stream.stream_state - if isinstance(state_value, AirbyteStateBlob): - state_value_mapping = {k: v for k, v in state_value.__dict__.items()} - state_message.stream.stream_state = state_value_mapping # type: ignore # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response - return state_message # type: ignore # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict - else: - return state_message # type: ignore # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 33731e74c..09bd921e1 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -95,14 +95,11 @@ def on_partition(self, partition: Partition) -> None: """ stream_name = partition.stream_name() self._streams_to_running_partitions[stream_name].add(partition) - cursor = self._stream_name_to_instance[stream_name].cursor if self._slice_logger.should_log_slice_message(self._logger): self._message_repository.emit_message( self._slice_logger.create_slice_log_message(partition.to_slice()) ) - self._thread_pool_manager.submit( - self._partition_reader.process_partition, partition, cursor - ) + self._thread_pool_manager.submit(self._partition_reader.process_partition, partition) def on_partition_complete_sentinel( self, sentinel: PartitionCompleteSentinel @@ -115,16 +112,26 @@ def on_partition_complete_sentinel( """ partition = sentinel.partition - partitions_running = self._streams_to_running_partitions[partition.stream_name()] - if partition in partitions_running: - partitions_running.remove(partition) - # If all partitions were generated and this was the last one, the stream is done - if ( - partition.stream_name() not in self._streams_currently_generating_partitions - and len(partitions_running) == 0 - ): - yield from self._on_stream_is_done(partition.stream_name()) - yield from self._message_repository.consume_queue() + try: + if sentinel.is_successful: + stream = self._stream_name_to_instance[partition.stream_name()] + stream.cursor.close_partition(partition) + except Exception as exception: + self._flag_exception(partition.stream_name(), exception) + yield AirbyteTracedException.from_exception( + exception, stream_descriptor=StreamDescriptor(name=partition.stream_name()) + ).as_sanitized_airbyte_message() + finally: + partitions_running = self._streams_to_running_partitions[partition.stream_name()] + if partition in partitions_running: + partitions_running.remove(partition) + # If all partitions were generated and this was the last one, the stream is done + if ( + partition.stream_name() not in self._streams_currently_generating_partitions + and len(partitions_running) == 0 + ): + yield from self._on_stream_is_done(partition.stream_name()) + yield from self._message_repository.consume_queue() def on_record(self, record: Record) -> Iterable[AirbyteMessage]: """ diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index 9ccfc1088..ffdee2dc1 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -4,7 +4,7 @@ import concurrent import logging from queue import Queue -from typing import Iterable, Iterator, List, Optional +from typing import Iterable, Iterator, List from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor @@ -16,7 +16,7 @@ from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer -from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader +from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import ( PartitionCompleteSentinel, @@ -43,7 +43,6 @@ def create( logger: logging.Logger, slice_logger: SliceLogger, message_repository: MessageRepository, - queue: Optional[Queue[QueueItem]] = None, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> "ConcurrentSource": is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1 @@ -60,13 +59,12 @@ def create( logger, ) return ConcurrentSource( - threadpool=threadpool, - logger=logger, - slice_logger=slice_logger, - queue=queue, - message_repository=message_repository, - initial_number_partitions_to_generate=initial_number_of_partitions_to_generate, - timeout_seconds=timeout_seconds, + threadpool, + logger, + slice_logger, + message_repository, + initial_number_of_partitions_to_generate, + timeout_seconds, ) def __init__( @@ -74,7 +72,6 @@ def __init__( threadpool: ThreadPoolManager, logger: logging.Logger, slice_logger: SliceLogger = DebugSliceLogger(), - queue: Optional[Queue[QueueItem]] = None, message_repository: MessageRepository = InMemoryMessageRepository(), initial_number_partitions_to_generate: int = 1, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, @@ -94,28 +91,25 @@ def __init__( self._initial_number_partitions_to_generate = initial_number_partitions_to_generate self._timeout_seconds = timeout_seconds - # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less - # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating - # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more - # information and might even need to be configurable depending on the source - self._queue = queue or Queue(maxsize=10_000) - def read( self, streams: List[AbstractStream], ) -> Iterator[AirbyteMessage]: self._logger.info("Starting syncing") + + # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less + # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating + # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more + # information and might even need to be configurable depending on the source + queue: Queue[QueueItem] = Queue(maxsize=10_000) concurrent_stream_processor = ConcurrentReadProcessor( streams, - PartitionEnqueuer(self._queue, self._threadpool), + PartitionEnqueuer(queue, self._threadpool), self._threadpool, self._logger, self._slice_logger, self._message_repository, - PartitionReader( - self._queue, - PartitionLogger(self._slice_logger, self._logger, self._message_repository), - ), + PartitionReader(queue), ) # Enqueue initial partition generation tasks @@ -123,7 +117,7 @@ def read( # Read from the queue until all partitions were generated and read yield from self._consume_from_queue( - self._queue, + queue, concurrent_stream_processor, ) self._threadpool.check_for_errors_and_shutdown() @@ -147,10 +141,7 @@ def _consume_from_queue( airbyte_message_or_record_or_exception, concurrent_stream_processor, ) - # In the event that a partition raises an exception, anything remaining in - # the queue will be missed because is_done() can raise an exception and exit - # out of this loop before remaining items are consumed - if queue.empty() and concurrent_stream_processor.is_done(): + if concurrent_stream_processor.is_done() and queue.empty(): # all partitions were generated and processed. we're done here break @@ -170,7 +161,5 @@ def _handle_item( yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item) elif isinstance(queue_item, Record): yield from concurrent_stream_processor.on_record(queue_item) - elif isinstance(queue_item, AirbyteMessage): - yield queue_item else: raise ValueError(f"Unknown queue item type: {type(queue_item)}") diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 8eca7b7dd..7accd1ac6 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,11 +3,7 @@ # import logging -from dataclasses import dataclass, field -from queue import Queue -from typing import Any, ClassVar, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple - -from airbyte_protocol_dataclasses.models import Level +from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple from airbyte_cdk.models import ( AirbyteCatalog, @@ -52,8 +48,6 @@ StreamSlicerPartitionGenerator, ) from airbyte_cdk.sources.declarative.types import ConnectionDefinition -from airbyte_cdk.sources.message.concurrent_repository import ConcurrentMessageRepository -from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository from airbyte_cdk.sources.source import TState from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -64,22 +58,6 @@ from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, FinalStateCursor from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream -from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem - - -@dataclass -class TestLimits: - __test__: ClassVar[bool] = False # Tell Pytest this is not a Pytest class, despite its name - - DEFAULT_MAX_PAGES_PER_SLICE: ClassVar[int] = 5 - DEFAULT_MAX_SLICES: ClassVar[int] = 5 - DEFAULT_MAX_RECORDS: ClassVar[int] = 100 - DEFAULT_MAX_STREAMS: ClassVar[int] = 100 - - max_records: int = field(default=DEFAULT_MAX_RECORDS) - max_pages_per_slice: int = field(default=DEFAULT_MAX_PAGES_PER_SLICE) - max_slices: int = field(default=DEFAULT_MAX_SLICES) - max_streams: int = field(default=DEFAULT_MAX_STREAMS) class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]): @@ -95,9 +73,7 @@ def __init__( source_config: ConnectionDefinition, debug: bool = False, emit_connector_builder_messages: bool = False, - migrate_manifest: bool = False, - normalize_manifest: bool = False, - limits: Optional[TestLimits] = None, + component_factory: Optional[ModelToComponentFactory] = None, config_path: Optional[str] = None, **kwargs: Any, ) -> None: @@ -105,40 +81,22 @@ def __init__( # no longer needs to store the original incoming state. But maybe there's an edge case? self._connector_state_manager = ConnectorStateManager(state=state) # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later - # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less - # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating - # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more - # information and might even need to be configurable depending on the source - queue: Queue[QueueItem] = Queue(maxsize=10_000) - message_repository = InMemoryMessageRepository( - Level.DEBUG if emit_connector_builder_messages else Level.INFO - ) - # To reduce the complexity of the concurrent framework, we are not enabling RFR with synthetic # cursors. We do this by no longer automatically instantiating RFR cursors when converting # the declarative models into runtime components. Concurrent sources will continue to checkpoint # incremental streams running in full refresh. - component_factory = ModelToComponentFactory( + component_factory = component_factory or ModelToComponentFactory( emit_connector_builder_messages=emit_connector_builder_messages, disable_resumable_full_refresh=True, - message_repository=ConcurrentMessageRepository(queue, message_repository), connector_state_manager=self._connector_state_manager, max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"), - limit_pages_fetched_per_slice=limits.max_pages_per_slice if limits else None, - limit_slices_fetched=limits.max_slices if limits else None, - disable_retries=True if limits else False, - disable_cache=True if limits else False, ) - self._limits = limits - super().__init__( source_config=source_config, config=config, debug=debug, emit_connector_builder_messages=emit_connector_builder_messages, - migrate_manifest=migrate_manifest, - normalize_manifest=normalize_manifest, component_factory=component_factory, config_path=config_path, ) @@ -168,7 +126,6 @@ def __init__( initial_number_of_partitions_to_generate=initial_number_of_partitions_to_generate, logger=self.logger, slice_logger=self._slice_logger, - queue=queue, message_repository=self.message_repository, ) @@ -330,9 +287,6 @@ def _group_streams( self.message_repository, ), stream_slicer=declarative_stream.retriever.stream_slicer, - slice_limit=self._limits.max_slices - if self._limits - else None, # technically not needed because create_declarative_stream() -> create_simple_retriever() will apply the decorator. But for consistency and depending how we build create_default_stream, this may be needed later ) else: if ( @@ -364,7 +318,6 @@ def _group_streams( self.message_repository, ), stream_slicer=cursor, - slice_limit=self._limits.max_slices if self._limits else None, ) concurrent_streams.append( @@ -396,9 +349,6 @@ def _group_streams( self.message_repository, ), declarative_stream.retriever.stream_slicer, - slice_limit=self._limits.max_slices - if self._limits - else None, # technically not needed because create_declarative_stream() -> create_simple_retriever() will apply the decorator. But for consistency and depending how we build create_default_stream, this may be needed later ) final_state_cursor = FinalStateCursor( @@ -460,7 +410,6 @@ def _group_streams( self.message_repository, ), perpartition_cursor, - slice_limit=self._limits.max_slices if self._limits else None, ) concurrent_streams.append( diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 5f450a080..e1d07fdc2 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -622,10 +622,6 @@ SchemaNormalizationModel.Default: TransformConfig.DefaultSchemaNormalization, } -# Ideally this should use the value defined in ConcurrentDeclarativeSource, but -# this would be a circular import -MAX_SLICES = 5 - class ModelToComponentFactory: EPOCH_DATETIME_FORMAT = "%s" diff --git a/airbyte_cdk/sources/declarative/requesters/http_requester.py b/airbyte_cdk/sources/declarative/requesters/http_requester.py index 3ce4c8540..6b0e65aab 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_requester.py +++ b/airbyte_cdk/sources/declarative/requesters/http_requester.py @@ -168,13 +168,7 @@ def _get_url( next_page_token=next_page_token, ) - full_url = ( - self._join_url(url_base, path) - if url_base - else self._join_url(url, path) - if path - else url - ) + full_url = self._join_url(url_base, path) if url_base else url + path if path else url return full_url diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py index 7073e48a6..94ee03a56 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py @@ -1,11 +1,8 @@ -# Copyright (c) 2025 Airbyte, Inc., all rights reserved. +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any, Iterable, Mapping, Optional, cast +from typing import Any, Iterable, Mapping, Optional from airbyte_cdk.sources.declarative.retrievers import Retriever -from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer_test_read_decorator import ( - StreamSlicerTestReadDecorator, -) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator @@ -86,23 +83,10 @@ def __hash__(self) -> int: class StreamSlicerPartitionGenerator(PartitionGenerator): def __init__( - self, - partition_factory: DeclarativePartitionFactory, - stream_slicer: StreamSlicer, - slice_limit: Optional[int] = None, + self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer ) -> None: self._partition_factory = partition_factory - - if slice_limit: - self._stream_slicer = cast( - StreamSlicer, - StreamSlicerTestReadDecorator( - wrapped_slicer=stream_slicer, - maximum_number_of_slices=slice_limit, - ), - ) - else: - self._stream_slicer = stream_slicer + self._stream_slicer = stream_slicer def generate(self) -> Iterable[Partition]: for stream_slice in self._stream_slicer.stream_slices(): diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py index d261c27e8..323c89196 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from itertools import islice -from typing import Any, Iterable +from typing import Any, Iterable, Mapping, Optional, Union from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer -from airbyte_cdk.sources.types import StreamSlice +from airbyte_cdk.sources.types import StreamSlice, StreamState @dataclass diff --git a/airbyte_cdk/sources/message/concurrent_repository.py b/airbyte_cdk/sources/message/concurrent_repository.py deleted file mode 100644 index 947ee4c46..000000000 --- a/airbyte_cdk/sources/message/concurrent_repository.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025 Airbyte, Inc., all rights reserved. - -from queue import Queue -from typing import Callable, Iterable - -from airbyte_cdk.models import AirbyteMessage, Level -from airbyte_cdk.sources.message.repository import LogMessage, MessageRepository -from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem - - -class ConcurrentMessageRepository(MessageRepository): - """ - Message repository that immediately loads messages onto the queue processed on the - main thread. This ensures that messages are processed in the correct order they are - received. The InMemoryMessageRepository implementation does not have guaranteed - ordering since whether to process the main thread vs. partitions is non-deterministic - and there can be a lag between reading the main-thread and consuming messages on the - MessageRepository. - - This is particularly important for the connector builder which relies on grouping - of messages to organize request/response, pages, and partitions. - """ - - def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepository): - self._queue = queue - self._decorated_message_repository = message_repository - - def emit_message(self, message: AirbyteMessage) -> None: - self._decorated_message_repository.emit_message(message) - for message in self._decorated_message_repository.consume_queue(): - self._queue.put(message) - - def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: - self._decorated_message_repository.log_message(level, message_provider) - for message in self._decorated_message_repository.consume_queue(): - self._queue.put(message) - - def consume_queue(self) -> Iterable[AirbyteMessage]: - """ - This method shouldn't need to be called because as part of emit_message() we are already - loading messages onto the queue processed on the main thread. - """ - yield from [] diff --git a/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte_cdk/sources/streams/concurrent/partition_reader.py index 39bb2de2e..3d23fd9cf 100644 --- a/airbyte_cdk/sources/streams/concurrent/partition_reader.py +++ b/airbyte_cdk/sources/streams/concurrent/partition_reader.py @@ -1,45 +1,14 @@ -# Copyright (c) 2025 Airbyte, Inc., all rights reserved. - -import logging +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# from queue import Queue -from typing import Optional from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException -from airbyte_cdk.sources.message.repository import MessageRepository -from airbyte_cdk.sources.streams.concurrent.cursor import Cursor from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import ( PartitionCompleteSentinel, QueueItem, ) -from airbyte_cdk.sources.utils.slice_logger import SliceLogger - - -# Since moving all the connector builder workflow to the concurrent CDK which required correct ordering -# of grouping log messages onto the main write thread using the ConcurrentMessageRepository, this -# separate flow and class that was used to log slices onto this partition's message_repository -# should just be replaced by emitting messages directly onto the repository instead of an intermediary. -class PartitionLogger: - """ - Helper class that provides a mechanism for passing a log message onto the current - partitions message repository - """ - - def __init__( - self, - slice_logger: SliceLogger, - logger: logging.Logger, - message_repository: MessageRepository, - ): - self._slice_logger = slice_logger - self._logger = logger - self._message_repository = message_repository - - def log(self, partition: Partition) -> None: - if self._slice_logger.should_log_slice_message(self._logger): - self._message_repository.emit_message( - self._slice_logger.create_slice_log_message(partition.to_slice()) - ) class PartitionReader: @@ -49,16 +18,13 @@ class PartitionReader: _IS_SUCCESSFUL = True - def __init__( - self, queue: Queue[QueueItem], partition_logger: Optional[PartitionLogger] = None - ) -> None: + def __init__(self, queue: Queue[QueueItem]) -> None: """ :param queue: The queue to put the records in. """ self._queue = queue - self._partition_logger = partition_logger - def process_partition(self, partition: Partition, cursor: Cursor) -> None: + def process_partition(self, partition: Partition) -> None: """ Process a partition and put the records in the output queue. When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. @@ -71,13 +37,8 @@ def process_partition(self, partition: Partition, cursor: Cursor) -> None: :return: None """ try: - if self._partition_logger: - self._partition_logger.log(partition) - for record in partition.read(): self._queue.put(record) - cursor.observe(record) - cursor.close_partition(partition) self._queue.put(PartitionCompleteSentinel(partition, self._IS_SUCCESSFUL)) except Exception as e: self._queue.put(StreamThreadException(e, partition.stream_name())) diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py index 3ae63c242..77644c6b9 100644 --- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py +++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -4,7 +4,6 @@ from typing import Any, Union -from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import ( PartitionGenerationCompletedSentinel, ) @@ -35,10 +34,5 @@ def __eq__(self, other: Any) -> bool: Typedef representing the items that can be added to the ThreadBasedConcurrentStream """ QueueItem = Union[ - Record, - Partition, - PartitionCompleteSentinel, - PartitionGenerationCompletedSentinel, - Exception, - AirbyteMessage, + Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception ] diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py index 4b29f3e0d..ee802a7a6 100644 --- a/airbyte_cdk/sources/utils/slice_logger.py +++ b/airbyte_cdk/sources/utils/slice_logger.py @@ -11,10 +11,6 @@ from airbyte_cdk.models import Type as MessageType -# Once everything runs on the concurrent CDK and we've cleaned up the legacy flows, we should try to remove -# this class and write messages directly to the message_repository instead of through the logger because for -# cases like the connector builder where ordering of messages is important, using the logger can cause -# messages to be grouped out of order. Alas work for a different day. class SliceLogger(ABC): """ SliceLogger is an interface that allows us to log slices of data in a uniform way. diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 993b42e8e..2587fb95a 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -17,6 +17,9 @@ from airbyte_cdk import connector_builder from airbyte_cdk.connector_builder.connector_builder_handler import ( + DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, + DEFAULT_MAXIMUM_NUMBER_OF_SLICES, + DEFAULT_MAXIMUM_RECORDS, TestLimits, create_source, get_limits, @@ -53,11 +56,8 @@ Type, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( - ConcurrentDeclarativeSource, - TestLimits, -) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream +from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicerTestReadDecorator from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -530,9 +530,7 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): config = copy.deepcopy(RESOLVE_MANIFEST_CONFIG) command = "resolve_manifest" config["__command"] = command - source = ConcurrentDeclarativeSource( - catalog=None, config=config, state=None, source_config=MANIFEST - ) + source = ManifestDeclarativeSource(source_config=MANIFEST) limits = TestLimits() resolved_manifest = handle_connector_builder_request( source, command, config, create_configured_catalog("dummy_stream"), _A_STATE, limits @@ -681,21 +679,19 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): def test_resolve_manifest_error_returns_error_response(): - class MockConcurrentDeclarativeSource: + class MockManifestDeclarativeSource: @property def resolved_manifest(self): raise ValueError - source = MockConcurrentDeclarativeSource() + source = MockManifestDeclarativeSource() response = resolve_manifest(source) assert "Error resolving manifest" in response.trace.error.message def test_read(): config = TEST_READ_CONFIG - source = ConcurrentDeclarativeSource( - catalog=None, config=config, state=None, source_config=MANIFEST - ) + source = ManifestDeclarativeSource(source_config=MANIFEST) real_record = AirbyteRecordMessage( data={"id": "1234", "key": "value"}, emitted_at=1, stream=_stream_name @@ -784,9 +780,7 @@ def test_config_update() -> None: "client_secret": "a client secret", "refresh_token": "a refresh token", } - source = ConcurrentDeclarativeSource( - catalog=None, config=config, state=None, source_config=manifest - ) + source = ManifestDeclarativeSource(source_config=manifest) refresh_request_response = { "access_token": "an updated access token", @@ -823,7 +817,7 @@ def cursor_field(self): def name(self): return _stream_name - class MockConcurrentDeclarativeSource: + class MockManifestDeclarativeSource: def streams(self, config): return [MockDeclarativeStream()] @@ -845,7 +839,7 @@ def check_config_against_spec(self) -> Literal[False]: stack_trace = "a stack trace" mock_from_exception.return_value = stack_trace - source = MockConcurrentDeclarativeSource() + source = MockManifestDeclarativeSource() limits = TestLimits() response = read_stream( source, @@ -887,22 +881,19 @@ def test_handle_429_response(): config = TEST_READ_CONFIG limits = TestLimits() - catalog = ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG) - source = create_source(config=config, limits=limits, catalog=catalog, state=None) + source = create_source(config, limits) with patch("requests.Session.send", return_value=response) as mock_send: response = handle_connector_builder_request( source, "test_read", config, - catalog, + ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG), _A_PER_PARTITION_STATE, limits, ) - # The test read will attempt a read for 5 partitions, and attempt 1 request - # each time that will not be retried - assert mock_send.call_count == 5 + mock_send.assert_called_once() @pytest.mark.parametrize( @@ -954,7 +945,7 @@ def test_invalid_config_command(invalid_config_file, dummy_catalog): @pytest.fixture def manifest_declarative_source(): - return mock.Mock(spec=ConcurrentDeclarativeSource, autospec=True) + return mock.Mock(spec=ManifestDeclarativeSource, autospec=True) def create_mock_retriever(name, url_base, path): @@ -979,16 +970,16 @@ def create_mock_declarative_stream(http_stream): ( "test_no_test_read_config", {}, - TestLimits.DEFAULT_MAX_RECORDS, - TestLimits.DEFAULT_MAX_SLICES, - TestLimits.DEFAULT_MAX_PAGES_PER_SLICE, + DEFAULT_MAXIMUM_RECORDS, + DEFAULT_MAXIMUM_NUMBER_OF_SLICES, + DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, ), ( "test_no_values_set", {"__test_read_config": {}}, - TestLimits.DEFAULT_MAX_RECORDS, - TestLimits.DEFAULT_MAX_SLICES, - TestLimits.DEFAULT_MAX_PAGES_PER_SLICE, + DEFAULT_MAXIMUM_RECORDS, + DEFAULT_MAXIMUM_NUMBER_OF_SLICES, + DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, ), ( "test_values_are_set", @@ -1016,9 +1007,9 @@ def test_create_source(): config = {"__injected_declarative_manifest": MANIFEST} - source = create_source(config=config, limits=limits, catalog=None, state=None) + source = create_source(config, limits) - assert isinstance(source, ConcurrentDeclarativeSource) + assert isinstance(source, ManifestDeclarativeSource) assert source._constructor._limit_pages_fetched_per_slice == limits.max_pages_per_slice assert source._constructor._limit_slices_fetched == limits.max_slices assert source._constructor._disable_cache @@ -1110,7 +1101,7 @@ def test_read_source(mock_http_stream): config = {"__injected_declarative_manifest": MANIFEST} - source = create_source(config=config, limits=limits, catalog=catalog, state=None) + source = create_source(config, limits) output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data slices = output_data["slices"] @@ -1158,7 +1149,7 @@ def test_read_source_single_page_single_slice(mock_http_stream): config = {"__injected_declarative_manifest": MANIFEST} - source = create_source(config=config, limits=limits, catalog=catalog, state=None) + source = create_source(config, limits) output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data slices = output_data["slices"] @@ -1245,7 +1236,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error test_manifest["streams"][0]["$parameters"]["url_base"] = url_base config = {"__injected_declarative_manifest": test_manifest} - source = create_source(config=config, limits=limits, catalog=catalog, state=None) + source = create_source(config, limits) with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): output_data = read_stream( @@ -1275,13 +1266,13 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error pytest.param( "CLOUD", "https://10.0.27.27/tokens/bearer", - "StreamThreadException", + "AirbyteTracedException", id="test_cloud_read_with_private_endpoint", ), pytest.param( "CLOUD", "http://unsecured.protocol/tokens/bearer", - "StreamThreadException", + "InvalidSchema", id="test_cloud_read_with_unsecured_endpoint", ), pytest.param( @@ -1341,7 +1332,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected ) config = {"__injected_declarative_manifest": test_manifest} - source = create_source(config=config, limits=limits, catalog=catalog, state=None) + source = create_source(config, limits) with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): output_data = read_stream( @@ -1398,9 +1389,7 @@ def test_read_stream_exception_with_secrets(): def test_full_resolve_manifest(valid_resolve_manifest_config_file): config = copy.deepcopy(RESOLVE_DYNAMIC_STREAM_MANIFEST_CONFIG) command = config["__command"] - source = ConcurrentDeclarativeSource( - catalog=None, config=config, state=None, source_config=DYNAMIC_STREAM_MANIFEST - ) + source = ManifestDeclarativeSource(source_config=DYNAMIC_STREAM_MANIFEST) limits = TestLimits(max_streams=2) with HttpMocker() as http_mocker: http_mocker.get( diff --git a/unit_tests/sources/declarative/requesters/test_http_requester.py b/unit_tests/sources/declarative/requesters/test_http_requester.py index 1d8f47950..8fce688d7 100644 --- a/unit_tests/sources/declarative/requesters/test_http_requester.py +++ b/unit_tests/sources/declarative/requesters/test_http_requester.py @@ -861,50 +861,6 @@ def test_join_url(test_name, base_url, path, expected_full_url): assert sent_request.url == expected_full_url -@pytest.mark.parametrize( - "test_name, url, path, expected_full_url", - [ - ("test_no_path", "https://airbyte.io/my_endpoint", None, "https://airbyte.io/my_endpoint"), - ( - "test_path_does_not_include_url", - "https://airbyte.io/my_endpoint", - "with_path", - "https://airbyte.io/my_endpoint/with_path", - ), - ( - "test_path_does_include_url", - "https://airbyte.io/my_endpoint", - "https://airbyte.io/my_endpoint/with_path", - "https://airbyte.io/my_endpoint/with_path", - ), - ( - "test_path_is_different_full_url", - "https://airbyte.io/my_endpoint", - "https://airbyte-paginated.io/my_paginated_endpoint", - "https://airbyte-paginated.io/my_paginated_endpoint", - ), - ], -) -def test_join_url_with_url_and_path(test_name, url, path, expected_full_url): - requester = HttpRequester( - name="name", - url=url, - path=path, - http_method=HttpMethod.GET, - request_options_provider=None, - config={}, - parameters={}, - error_handler=DefaultErrorHandler(parameters={}, config={}), - ) - requester._http_client._session.send = MagicMock() - response = requests.Response() - response.status_code = 200 - requester._http_client._session.send.return_value = response - requester.send_request() - sent_request: PreparedRequest = requester._http_client._session.send.call_args_list[0][0][0] - assert sent_request.url == expected_full_url - - @pytest.mark.usefixtures("mock_sleep") def test_request_attempt_count_is_tracked_across_retries(http_requester_factory): request_mock = MagicMock(spec=requests.PreparedRequest) diff --git a/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py b/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py index 20147465f..97f89879c 100644 --- a/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py @@ -10,7 +10,9 @@ from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, - TestLimits, +) +from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( + ModelToComponentFactory, ) from airbyte_cdk.sources.declarative.schema import DynamicSchemaLoader, SchemaTypeIdentifier from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -351,13 +353,14 @@ def test_dynamic_schema_loader_with_type_conditions(): }, }, } - source = ConcurrentDeclarativeSource( source_config=_MANIFEST_WITH_TYPE_CONDITIONS, config=_CONFIG, catalog=None, state=None, - limits=TestLimits(), # Avoid caching on the HttpClient which could result in caching the requests/responses of other tests + component_factory=ModelToComponentFactory( + disable_cache=True + ), # Avoid caching on the HttpClient which could result in caching the requests/responses of other tests ) with HttpMocker() as http_mocker: http_mocker.get( diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 75b52f6b2..50695ba1e 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -50,10 +50,7 @@ def __init__( self._message_repository = InMemoryMessageRepository() threadpool_manager = ThreadPoolManager(threadpool, streams[0].logger) concurrent_source = ConcurrentSource( - threadpool=threadpool_manager, - logger=streams[0].logger, - slice_logger=NeverLogSliceLogger(), - message_repository=self._message_repository, + threadpool_manager, streams[0].logger, NeverLogSliceLogger(), self._message_repository ) super().__init__(concurrent_source) self._streams = streams diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index a681f75eb..d6ea64583 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -176,12 +176,10 @@ def test_handle_partition(self): self._partition_reader, ) - expected_cursor = handler._stream_name_to_instance[_ANOTHER_STREAM_NAME].cursor - handler.on_partition(self._a_closed_partition) self._thread_pool_manager.submit.assert_called_with( - self._partition_reader.process_partition, self._a_closed_partition, expected_cursor + self._partition_reader.process_partition, self._a_closed_partition ) assert ( self._a_closed_partition in handler._streams_to_running_partitions[_ANOTHER_STREAM_NAME] @@ -203,12 +201,10 @@ def test_handle_partition_emits_log_message_if_it_should_be_logged(self): self._partition_reader, ) - expected_cursor = handler._stream_name_to_instance[_STREAM_NAME].cursor - handler.on_partition(self._an_open_partition) self._thread_pool_manager.submit.assert_called_with( - self._partition_reader.process_partition, self._an_open_partition, expected_cursor + self._partition_reader.process_partition, self._an_open_partition ) self._message_repository.emit_message.assert_called_with(self._log_message) @@ -257,6 +253,8 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel ] assert messages == expected_messages + self._stream.cursor.close_partition.assert_called_once() + @freezegun.freeze_time("2020-01-01T00:00:00") def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done( self, @@ -304,6 +302,55 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre ) ] assert messages == expected_messages + self._another_stream.cursor.close_partition.assert_called_once() + + @freezegun.freeze_time("2020-01-01T00:00:00") + def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete( + self, + ) -> None: + self._a_closed_partition.stream_name.return_value = self._stream.name + self._stream.cursor.close_partition.side_effect = ValueError + + handler = ConcurrentReadProcessor( + [self._stream], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + handler.start_next_partition_generator() + handler.on_partition(self._a_closed_partition) + list( + handler.on_partition_generation_completed( + PartitionGenerationCompletedSentinel(self._stream) + ) + ) + messages = list( + handler.on_partition_complete_sentinel( + PartitionCompleteSentinel(self._a_closed_partition) + ) + ) + + expected_status_message = AirbyteMessage( + type=MessageType.TRACE, + trace=AirbyteTraceMessage( + type=TraceType.STREAM_STATUS, + stream_status=AirbyteStreamStatusTraceMessage( + stream_descriptor=StreamDescriptor( + name=self._stream.name, + ), + status=AirbyteStreamStatus.INCOMPLETE, + ), + emitted_at=1577836800000.0, + ), + ) + assert list(map(lambda message: message.trace.type, messages)) == [ + TraceType.ERROR, + TraceType.STREAM_STATUS, + ] + assert messages[1] == expected_status_message @freezegun.freeze_time("2020-01-01T00:00:00") def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_stream_is_not_done( @@ -332,6 +379,7 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s expected_messages = [] assert messages == expected_messages + self._stream.cursor.close_partition.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") def test_on_record_no_status_message_no_repository_messge(self): diff --git a/unit_tests/sources/streams/concurrent/test_partition_reader.py b/unit_tests/sources/streams/concurrent/test_partition_reader.py index a41750772..1910e034d 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_reader.py +++ b/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -1,5 +1,6 @@ -# Copyright (c) 2025 Airbyte, Inc., all rights reserved. - +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# import unittest from queue import Queue from typing import Callable, Iterable, List @@ -7,9 +8,7 @@ import pytest -from airbyte_cdk import InMemoryMessageRepository from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException -from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import ( @@ -27,15 +26,10 @@ class PartitionReaderTest(unittest.TestCase): def setUp(self) -> None: self._queue: Queue[QueueItem] = Queue() - self._partition_reader = PartitionReader(self._queue, None) + self._partition_reader = PartitionReader(self._queue) def test_given_no_records_when_process_partition_then_only_emit_sentinel(self): - cursor = FinalStateCursor( - stream_name="test", - stream_namespace=None, - message_repository=InMemoryMessageRepository(), - ) - self._partition_reader.process_partition(self._a_partition([]), cursor) + self._partition_reader.process_partition(self._a_partition([])) while queue_item := self._queue.get(): if not isinstance(queue_item, PartitionCompleteSentinel): @@ -46,24 +40,19 @@ def test_given_read_partition_successful_when_process_partition_then_queue_recor self, ): partition = self._a_partition(_RECORDS) - cursor = Mock() - self._partition_reader.process_partition(partition, cursor) + self._partition_reader.process_partition(partition) queue_content = self._consume_queue() assert queue_content == _RECORDS + [PartitionCompleteSentinel(partition)] - cursor.observe.assert_called() - cursor.close_partition.assert_called_once() - - def test_given_exception_from_read_when_process_partition_then_queue_records_and_exception_and_sentinel( + def test_given_exception_when_process_partition_then_queue_records_and_exception_and_sentinel( self, ): partition = Mock() - cursor = Mock() exception = ValueError() partition.read.side_effect = self._read_with_exception(_RECORDS, exception) - self._partition_reader.process_partition(partition, cursor) + self._partition_reader.process_partition(partition) queue_content = self._consume_queue() @@ -72,23 +61,6 @@ def test_given_exception_from_read_when_process_partition_then_queue_records_and PartitionCompleteSentinel(partition), ] - def test_given_exception_from_close_slice_when_process_partition_then_queue_records_and_exception_and_sentinel( - self, - ): - partition = self._a_partition(_RECORDS) - cursor = Mock() - exception = ValueError() - cursor.close_partition.side_effect = self._close_partition_with_exception(exception) - self._partition_reader.process_partition(partition, cursor) - - queue_content = self._consume_queue() - - # 4 total messages in queue. 2 records, 1 thread exception, 1 partition sentinel value - assert len(queue_content) == 4 - assert queue_content[:2] == _RECORDS - assert isinstance(queue_content[2], StreamThreadException) - assert queue_content[3] == PartitionCompleteSentinel(partition) - def _a_partition(self, records: List[Record]) -> Partition: partition = Mock(spec=Partition) partition.read.return_value = iter(records) @@ -104,13 +76,6 @@ def mocked_function() -> Iterable[Record]: return mocked_function - @staticmethod - def _close_partition_with_exception(exception: Exception) -> Callable[[Partition], None]: - def mocked_function(partition: Partition) -> None: - raise exception - - return mocked_function - def _consume_queue(self): queue_content = [] while queue_item := self._queue.get():