From 5f465c5220483100e503adbe36bce6a6ebd1bae9 Mon Sep 17 00:00:00 2001 From: Brian Lai <51336873+brianjlai@users.noreply.github.com> Date: Mon, 18 Aug 2025 17:16:00 -0700 Subject: [PATCH 01/17] =?UTF-8?q?Revert=20"fix:=20revert=20remerge=20concu?= =?UTF-8?q?rrent=20cdk=20builder=20change=20because=20of=20flaky=20te?= =?UTF-8?q?=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 1c9049a9366541dd9f59e686ce462cfe2803fd47. --- .../connector_builder_handler.py | 68 +++++---- airbyte_cdk/connector_builder/main.py | 6 +- .../connector_builder/test_reader/helpers.py | 26 +++- .../test_reader/message_grouper.py | 2 +- .../concurrent_read_processor.py | 35 ++--- .../concurrent_source/concurrent_source.py | 47 ++++--- .../concurrent_declarative_source.py | 97 ++++++++++--- .../parsers/model_to_component_factory.py | 4 + .../declarative_partition_generator.py | 56 ++++++-- .../stream_slicer_test_read_decorator.py | 4 +- .../sources/message/concurrent_repository.py | 43 ++++++ .../streams/concurrent/partition_reader.py | 51 ++++++- .../streams/concurrent/partitions/types.py | 8 +- airbyte_cdk/sources/utils/slice_logger.py | 4 + .../test_connector_builder_handler.py | 71 ++++++---- .../connector_builder/test_message_grouper.py | 120 ---------------- .../test_concurrent_perpartitioncursor.py | 132 ++++++++++++++++-- .../retrievers/test_simple_retriever.py | 9 +- .../schema/test_dynamic_schema_loader.py | 9 +- .../test_declarative_partition_generator.py | 85 ++++++++++- .../scenarios/stream_facade_builder.py | 5 +- .../test_concurrent_read_processor.py | 60 +------- .../concurrent/test_partition_reader.py | 51 +++++-- 23 files changed, 635 insertions(+), 358 deletions(-) create mode 100644 airbyte_cdk/sources/message/concurrent_repository.py diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index 513546737..a7d2163a9 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -3,8 +3,8 @@ # -from dataclasses import asdict, dataclass, field -from typing import Any, ClassVar, Dict, List, Mapping +from dataclasses import asdict +from typing import Any, Dict, List, Mapping, Optional from airbyte_cdk.connector_builder.test_reader import TestReader from airbyte_cdk.models import ( @@ -15,45 +15,32 @@ Type, ) from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, + TestLimits, +) from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( - ModelToComponentFactory, -) from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets from airbyte_cdk.utils.datetime_helpers import ab_datetime_now from airbyte_cdk.utils.traced_exception import AirbyteTracedException -DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5 -DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5 -DEFAULT_MAXIMUM_RECORDS = 100 -DEFAULT_MAXIMUM_STREAMS = 100 - MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice" MAX_SLICES_KEY = "max_slices" MAX_RECORDS_KEY = "max_records" MAX_STREAMS_KEY = "max_streams" -@dataclass -class TestLimits: - __test__: ClassVar[bool] = False # Tell Pytest this is not a Pytest class, despite its name - - 
max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS) - max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE) - max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES) - max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS) - - def get_limits(config: Mapping[str, Any]) -> TestLimits: command_config = config.get("__test_read_config", {}) - max_pages_per_slice = ( - command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE + return TestLimits( + max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS), + max_pages_per_slice=command_config.get( + MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE + ), + max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES), + max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS), ) - max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES - max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS - max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS - return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams) def should_migrate_manifest(config: Mapping[str, Any]) -> bool: @@ -75,21 +62,30 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool: return config.get("__should_normalize", False) -def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource: +def create_source( + config: Mapping[str, Any], + limits: TestLimits, + catalog: Optional[ConfiguredAirbyteCatalog], + state: Optional[List[AirbyteStateMessage]], +) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]: manifest = config["__injected_declarative_manifest"] - return ManifestDeclarativeSource( + + # We enforce a concurrency level of 1 so that the stream is processed on a single thread + # to retain ordering for the grouping of the builder message responses. 
+ if "concurrency_level" in manifest: + manifest["concurrency_level"]["default_concurrency"] = 1 + else: + manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1} + + return ConcurrentDeclarativeSource( + catalog=catalog, config=config, - emit_connector_builder_messages=True, + state=state, source_config=manifest, + emit_connector_builder_messages=True, migrate_manifest=should_migrate_manifest(config), normalize_manifest=should_normalize_manifest(config), - component_factory=ModelToComponentFactory( - emit_connector_builder_messages=True, - limit_pages_fetched_per_slice=limits.max_pages_per_slice, - limit_slices_fetched=limits.max_slices, - disable_retries=True, - disable_cache=True, - ), + limits=limits, ) diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index 80cf4afa9..22be81c82 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -91,12 +91,12 @@ def handle_connector_builder_request( def handle_request(args: List[str]) -> str: command, config, catalog, state = get_config_and_catalog_from_args(args) limits = get_limits(config) - source = create_source(config, limits) - return orjson.dumps( + source = create_source(config=config, limits=limits, catalog=catalog, state=state) + return orjson.dumps( # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage AirbyteMessageSerializer.dump( handle_connector_builder_request(source, command, config, catalog, state, limits) ) - ).decode() # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage + ).decode() if __name__ == "__main__": diff --git a/airbyte_cdk/connector_builder/test_reader/helpers.py b/airbyte_cdk/connector_builder/test_reader/helpers.py index 9154610cc..3cc634ccb 100644 --- a/airbyte_cdk/connector_builder/test_reader/helpers.py +++ b/airbyte_cdk/connector_builder/test_reader/helpers.py @@ -5,7 +5,7 @@ import json from copy import deepcopy from json import JSONDecodeError -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, Union from airbyte_cdk.connector_builder.models import ( AuxiliaryRequest, @@ -17,6 +17,8 @@ from airbyte_cdk.models import ( AirbyteLogMessage, AirbyteMessage, + AirbyteStateBlob, + AirbyteStateMessage, OrchestratorType, TraceType, ) @@ -466,7 +468,7 @@ def handle_current_slice( return StreamReadSlices( pages=current_slice_pages, slice_descriptor=current_slice_descriptor, - state=[latest_state_message] if latest_state_message else [], + state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [], auxiliary_requests=auxiliary_requests if auxiliary_requests else [], ) @@ -718,3 +720,23 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str: # type: ignore Determines the type of the auxiliary request based on the stream and HTTP properties. """ return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None)) + + +def convert_state_blob_to_mapping( + state_message: Union[AirbyteStateMessage, Dict[str, Any]], +) -> Dict[str, Any]: + """ + The AirbyteStreamState stores state as an AirbyteStateBlob which deceivingly is not + a dictionary, but rather a list of kwargs fields. 
This in turn causes it to not be + properly turned into a dictionary when translating this back into response output + by the connector_builder_handler using asdict() + """ + + if isinstance(state_message, AirbyteStateMessage) and state_message.stream: + state_value = state_message.stream.stream_state + if isinstance(state_value, AirbyteStateBlob): + state_value_mapping = {k: v for k, v in state_value.__dict__.items()} + state_message.stream.stream_state = state_value_mapping # type: ignore # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response + return state_message # type: ignore # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict + else: + return state_message # type: ignore # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above diff --git a/airbyte_cdk/connector_builder/test_reader/message_grouper.py b/airbyte_cdk/connector_builder/test_reader/message_grouper.py index 33b594451..999b54b72 100644 --- a/airbyte_cdk/connector_builder/test_reader/message_grouper.py +++ b/airbyte_cdk/connector_builder/test_reader/message_grouper.py @@ -95,7 +95,7 @@ def get_message_groups( latest_state_message: Optional[Dict[str, Any]] = None slice_auxiliary_requests: List[AuxiliaryRequest] = [] - while records_count < limit and (message := next(messages, None)): + while message := next(messages, None): json_message = airbyte_message_to_json(message) if is_page_http_request_for_different_stream(json_message, stream_name): diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 09bd921e1..33731e74c 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -95,11 +95,14 @@ def on_partition(self, partition: Partition) -> None: """ stream_name = partition.stream_name() self._streams_to_running_partitions[stream_name].add(partition) + cursor = self._stream_name_to_instance[stream_name].cursor if self._slice_logger.should_log_slice_message(self._logger): self._message_repository.emit_message( self._slice_logger.create_slice_log_message(partition.to_slice()) ) - self._thread_pool_manager.submit(self._partition_reader.process_partition, partition) + self._thread_pool_manager.submit( + self._partition_reader.process_partition, partition, cursor + ) def on_partition_complete_sentinel( self, sentinel: PartitionCompleteSentinel @@ -112,26 +115,16 @@ def on_partition_complete_sentinel( """ partition = sentinel.partition - try: - if sentinel.is_successful: - stream = self._stream_name_to_instance[partition.stream_name()] - stream.cursor.close_partition(partition) - except Exception as exception: - self._flag_exception(partition.stream_name(), exception) - yield AirbyteTracedException.from_exception( - exception, stream_descriptor=StreamDescriptor(name=partition.stream_name()) - ).as_sanitized_airbyte_message() - finally: - partitions_running = self._streams_to_running_partitions[partition.stream_name()] - if partition in partitions_running: - partitions_running.remove(partition) - # If all partitions were generated and this was the last one, the stream is done - if ( - partition.stream_name() not in self._streams_currently_generating_partitions - and len(partitions_running) == 0 - ): - yield from self._on_stream_is_done(partition.stream_name()) - yield from 
self._message_repository.consume_queue() + partitions_running = self._streams_to_running_partitions[partition.stream_name()] + if partition in partitions_running: + partitions_running.remove(partition) + # If all partitions were generated and this was the last one, the stream is done + if ( + partition.stream_name() not in self._streams_currently_generating_partitions + and len(partitions_running) == 0 + ): + yield from self._on_stream_is_done(partition.stream_name()) + yield from self._message_repository.consume_queue() def on_record(self, record: Record) -> Iterable[AirbyteMessage]: """ diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index ffdee2dc1..9ccfc1088 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -4,7 +4,7 @@ import concurrent import logging from queue import Queue -from typing import Iterable, Iterator, List +from typing import Iterable, Iterator, List, Optional from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor @@ -16,7 +16,7 @@ from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer -from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader +from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionLogger, PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import ( PartitionCompleteSentinel, @@ -43,6 +43,7 @@ def create( logger: logging.Logger, slice_logger: SliceLogger, message_repository: MessageRepository, + queue: Optional[Queue[QueueItem]] = None, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, ) -> "ConcurrentSource": is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1 @@ -59,12 +60,13 @@ def create( logger, ) return ConcurrentSource( - threadpool, - logger, - slice_logger, - message_repository, - initial_number_of_partitions_to_generate, - timeout_seconds, + threadpool=threadpool, + logger=logger, + slice_logger=slice_logger, + queue=queue, + message_repository=message_repository, + initial_number_partitions_to_generate=initial_number_of_partitions_to_generate, + timeout_seconds=timeout_seconds, ) def __init__( @@ -72,6 +74,7 @@ def __init__( threadpool: ThreadPoolManager, logger: logging.Logger, slice_logger: SliceLogger = DebugSliceLogger(), + queue: Optional[Queue[QueueItem]] = None, message_repository: MessageRepository = InMemoryMessageRepository(), initial_number_partitions_to_generate: int = 1, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, @@ -91,25 +94,28 @@ def __init__( self._initial_number_partitions_to_generate = initial_number_partitions_to_generate self._timeout_seconds = timeout_seconds + # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less + # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating + # partitions which would fill the queue. 
This number is arbitrarily set to 10_000 but will probably need to be changed given more + # information and might even need to be configurable depending on the source + self._queue = queue or Queue(maxsize=10_000) + def read( self, streams: List[AbstractStream], ) -> Iterator[AirbyteMessage]: self._logger.info("Starting syncing") - - # We set a maxsize to for the main thread to process record items when the queue size grows. This assumes that there are less - # threads generating partitions that than are max number of workers. If it weren't the case, we could have threads only generating - # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more - # information and might even need to be configurable depending on the source - queue: Queue[QueueItem] = Queue(maxsize=10_000) concurrent_stream_processor = ConcurrentReadProcessor( streams, - PartitionEnqueuer(queue, self._threadpool), + PartitionEnqueuer(self._queue, self._threadpool), self._threadpool, self._logger, self._slice_logger, self._message_repository, - PartitionReader(queue), + PartitionReader( + self._queue, + PartitionLogger(self._slice_logger, self._logger, self._message_repository), + ), ) # Enqueue initial partition generation tasks @@ -117,7 +123,7 @@ def read( # Read from the queue until all partitions were generated and read yield from self._consume_from_queue( - queue, + self._queue, concurrent_stream_processor, ) self._threadpool.check_for_errors_and_shutdown() @@ -141,7 +147,10 @@ def _consume_from_queue( airbyte_message_or_record_or_exception, concurrent_stream_processor, ) - if concurrent_stream_processor.is_done() and queue.empty(): + # In the event that a partition raises an exception, anything remaining in + # the queue will be missed because is_done() can raise an exception and exit + # out of this loop before remaining items are consumed + if queue.empty() and concurrent_stream_processor.is_done(): # all partitions were generated and processed. 
we're done here
                 break
 
@@ -161,5 +170,7 @@ def _handle_item(
             yield from concurrent_stream_processor.on_partition_complete_sentinel(queue_item)
         elif isinstance(queue_item, Record):
             yield from concurrent_stream_processor.on_record(queue_item)
+        elif isinstance(queue_item, AirbyteMessage):
+            yield queue_item
         else:
             raise ValueError(f"Unknown queue item type: {type(queue_item)}")
diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
index 2bcc4b8c9..b5949bb19 100644
--- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -3,7 +3,11 @@
 #
 
 import logging
-from typing import Any, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple
+from dataclasses import dataclass, field
+from queue import Queue
+from typing import Any, ClassVar, Generic, Iterator, List, Mapping, MutableMapping, Optional, Tuple
+
+from airbyte_protocol_dataclasses.models import Level
 
 from airbyte_cdk.models import (
     AirbyteCatalog,
@@ -48,6 +52,8 @@
     StreamSlicerPartitionGenerator,
 )
 from airbyte_cdk.sources.declarative.types import ConnectionDefinition
+from airbyte_cdk.sources.message.concurrent_repository import ConcurrentMessageRepository
+from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
 from airbyte_cdk.sources.source import TState
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
@@ -55,6 +61,22 @@
 from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, FinalStateCursor
 from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
 from airbyte_cdk.sources.streams.concurrent.helpers import get_primary_key_from_stream
+from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
+
+
+@dataclass
+class TestLimits:
+    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class, despite its name
+
+    DEFAULT_MAX_PAGES_PER_SLICE: ClassVar[int] = 5
+    DEFAULT_MAX_SLICES: ClassVar[int] = 5
+    DEFAULT_MAX_RECORDS: ClassVar[int] = 100
+    DEFAULT_MAX_STREAMS: ClassVar[int] = 100
+
+    max_records: int = field(default=DEFAULT_MAX_RECORDS)
+    max_pages_per_slice: int = field(default=DEFAULT_MAX_PAGES_PER_SLICE)
+    max_slices: int = field(default=DEFAULT_MAX_SLICES)
+    max_streams: int = field(default=DEFAULT_MAX_STREAMS)
 
 
 class ConcurrentDeclarativeSource(ManifestDeclarativeSource, Generic[TState]):
@@ -70,7 +92,9 @@ def __init__(
         source_config: ConnectionDefinition,
         debug: bool = False,
         emit_connector_builder_messages: bool = False,
-        component_factory: Optional[ModelToComponentFactory] = None,
+        migrate_manifest: bool = False,
+        normalize_manifest: bool = False,
+        limits: Optional[TestLimits] = None,
         config_path: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
@@ -78,22 +102,40 @@ def __init__(
         # no longer needs to store the original incoming state. But maybe there's an edge case?
         self._connector_state_manager = ConnectorStateManager(state=state)  # type: ignore # state is always in the form of List[AirbyteStateMessage]. The ConnectorStateManager should use generics, but this can be done later
 
+        # We set a maxsize for the main thread to process record items when the queue size grows. This assumes that there are fewer
+        # threads generating partitions than the max number of workers. 
If it weren't the case, we could have threads only generating + # partitions which would fill the queue. This number is arbitrarily set to 10_000 but will probably need to be changed given more + # information and might even need to be configurable depending on the source + queue: Queue[QueueItem] = Queue(maxsize=10_000) + message_repository = InMemoryMessageRepository( + Level.DEBUG if emit_connector_builder_messages else Level.INFO + ) + # To reduce the complexity of the concurrent framework, we are not enabling RFR with synthetic # cursors. We do this by no longer automatically instantiating RFR cursors when converting # the declarative models into runtime components. Concurrent sources will continue to checkpoint # incremental streams running in full refresh. - component_factory = component_factory or ModelToComponentFactory( + component_factory = ModelToComponentFactory( emit_connector_builder_messages=emit_connector_builder_messages, disable_resumable_full_refresh=True, + message_repository=ConcurrentMessageRepository(queue, message_repository), connector_state_manager=self._connector_state_manager, max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"), + limit_pages_fetched_per_slice=limits.max_pages_per_slice if limits else None, + limit_slices_fetched=limits.max_slices if limits else None, + disable_retries=True if limits else False, + disable_cache=True if limits else False, ) + self._limits = limits + super().__init__( source_config=source_config, config=config, debug=debug, emit_connector_builder_messages=emit_connector_builder_messages, + migrate_manifest=migrate_manifest, + normalize_manifest=normalize_manifest, component_factory=component_factory, config_path=config_path, ) @@ -123,6 +165,7 @@ def __init__( initial_number_of_partitions_to_generate=initial_number_of_partitions_to_generate, logger=self.logger, slice_logger=self._slice_logger, + queue=queue, message_repository=self.message_repository, ) @@ -278,12 +321,18 @@ def _group_streams( partition_generator = StreamSlicerPartitionGenerator( partition_factory=DeclarativePartitionFactory( - declarative_stream.name, - declarative_stream.get_json_schema(), - retriever, - self.message_repository, + stream_name=declarative_stream.name, + json_schema=declarative_stream.get_json_schema(), + retriever=retriever, + message_repository=self.message_repository, + max_records_limit=self._limits.max_records + if self._limits + else None, ), stream_slicer=declarative_stream.retriever.stream_slicer, + slice_limit=self._limits.max_slices + if self._limits + else None, # technically not needed because create_declarative_stream() -> create_simple_retriever() will apply the decorator. 
But for consistency and depending on how we build create_default_stream, this may be needed later
                     )
                 else:
                     if (
@@ -309,12 +358,16 @@ def _group_streams(
                     )
                     partition_generator = StreamSlicerPartitionGenerator(
                         partition_factory=DeclarativePartitionFactory(
-                            declarative_stream.name,
-                            declarative_stream.get_json_schema(),
-                            retriever,
-                            self.message_repository,
+                            stream_name=declarative_stream.name,
+                            json_schema=declarative_stream.get_json_schema(),
+                            retriever=retriever,
+                            message_repository=self.message_repository,
+                            max_records_limit=self._limits.max_records
+                            if self._limits
+                            else None,
                         ),
                         stream_slicer=cursor,
+                        slice_limit=self._limits.max_slices if self._limits else None,
                     )
 
                     concurrent_streams.append(
@@ -339,12 +392,16 @@ def _group_streams(
                 ) and hasattr(declarative_stream.retriever, "stream_slicer"):
                     partition_generator = StreamSlicerPartitionGenerator(
                         DeclarativePartitionFactory(
-                            declarative_stream.name,
-                            declarative_stream.get_json_schema(),
-                            declarative_stream.retriever,
-                            self.message_repository,
+                            stream_name=declarative_stream.name,
+                            json_schema=declarative_stream.get_json_schema(),
+                            retriever=declarative_stream.retriever,
+                            message_repository=self.message_repository,
+                            max_records_limit=self._limits.max_records if self._limits else None,
                         ),
                         declarative_stream.retriever.stream_slicer,
+                        slice_limit=self._limits.max_slices
+                        if self._limits
+                        else None,  # technically not needed because create_declarative_stream() -> create_simple_retriever() will apply the decorator. But for consistency and depending on how we build create_default_stream, this may be needed later
                     )
 
                     final_state_cursor = FinalStateCursor(
@@ -399,12 +456,14 @@ def _group_streams(
 
                     partition_generator = StreamSlicerPartitionGenerator(
                         DeclarativePartitionFactory(
-                            declarative_stream.name,
-                            declarative_stream.get_json_schema(),
-                            retriever,
-                            self.message_repository,
+                            stream_name=declarative_stream.name,
+                            json_schema=declarative_stream.get_json_schema(),
+                            retriever=retriever,
+                            message_repository=self.message_repository,
+                            max_records_limit=self._limits.max_records if self._limits else None,
                         ),
                         perpartition_cursor,
+                        slice_limit=self._limits.max_slices if self._limits else None,
                     )
 
                     concurrent_streams.append(
diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
index e1d07fdc2..5f450a080 100644
--- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
+++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -622,6 +622,10 @@
     SchemaNormalizationModel.Default: TransformConfig.DefaultSchemaNormalization,
 }
 
+# Ideally this should use the value defined in ConcurrentDeclarativeSource, but
+# this would be a circular import
+MAX_SLICES = 5
+
 
 class ModelToComponentFactory:
     EPOCH_DATETIME_FORMAT = "%s"
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
index 94ee03a56..882d760ea 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -1,8 +1,11 @@
-# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
-from typing import Any, Iterable, Mapping, Optional +from typing import Any, Iterable, Mapping, Optional, cast from airbyte_cdk.sources.declarative.retrievers import Retriever +from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer_test_read_decorator import ( + StreamSlicerTestReadDecorator, +) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator @@ -10,6 +13,11 @@ from airbyte_cdk.sources.types import Record, StreamSlice from airbyte_cdk.utils.slice_hasher import SliceHasher +# For Connector Builder test read operations, we track the total number of records +# read for the stream at the global level so that we can stop reading early if we +# exceed the record limit +total_record_counter = 0 + class DeclarativePartitionFactory: def __init__( @@ -18,6 +26,7 @@ def __init__( json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository, + max_records_limit: Optional[int] = None, ) -> None: """ The DeclarativePartitionFactory takes a retriever_factory and not a retriever directly. The reason is that our components are not @@ -28,14 +37,16 @@ def __init__( self._json_schema = json_schema self._retriever = retriever self._message_repository = message_repository + self._max_records_limit = max_records_limit def create(self, stream_slice: StreamSlice) -> Partition: return DeclarativePartition( - self._stream_name, - self._json_schema, - self._retriever, - self._message_repository, - stream_slice, + stream_name=self._stream_name, + json_schema=self._json_schema, + retriever=self._retriever, + message_repository=self._message_repository, + max_records_limit=self._max_records_limit, + stream_slice=stream_slice, ) @@ -46,17 +57,27 @@ def __init__( json_schema: Mapping[str, Any], retriever: Retriever, message_repository: MessageRepository, + max_records_limit: Optional[int], stream_slice: StreamSlice, ): self._stream_name = stream_name self._json_schema = json_schema self._retriever = retriever self._message_repository = message_repository + self._max_records_limit = max_records_limit self._stream_slice = stream_slice self._hash = SliceHasher.hash(self._stream_name, self._stream_slice) def read(self) -> Iterable[Record]: + if self._max_records_limit: + global total_record_counter + if total_record_counter >= self._max_records_limit: + return for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice): + if self._max_records_limit: + if total_record_counter >= self._max_records_limit: + break + if isinstance(stream_data, Mapping): record = ( stream_data @@ -71,6 +92,9 @@ def read(self) -> Iterable[Record]: else: self._message_repository.emit_message(stream_data) + if self._max_records_limit: + total_record_counter += 1 + def to_slice(self) -> Optional[Mapping[str, Any]]: return self._stream_slice @@ -83,10 +107,24 @@ def __hash__(self) -> int: class StreamSlicerPartitionGenerator(PartitionGenerator): def __init__( - self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer + self, + partition_factory: DeclarativePartitionFactory, + stream_slicer: StreamSlicer, + slice_limit: Optional[int] = None, + max_records_limit: Optional[int] = None, ) -> None: self._partition_factory = partition_factory - self._stream_slicer = stream_slicer + + if slice_limit: + self._stream_slicer = cast( + StreamSlicer, + 
StreamSlicerTestReadDecorator(
+                    wrapped_slicer=stream_slicer,
+                    maximum_number_of_slices=slice_limit,
+                ),
+            )
+        else:
+            self._stream_slicer = stream_slicer
 
     def generate(self) -> Iterable[Partition]:
         for stream_slice in self._stream_slicer.stream_slices():
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py
index 323c89196..d261c27e8 100644
--- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py
+++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer_test_read_decorator.py
@@ -4,10 +4,10 @@
 
 from dataclasses import dataclass
 from itertools import islice
-from typing import Any, Iterable, Mapping, Optional, Union
+from typing import Any, Iterable
 
 from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
-from airbyte_cdk.sources.types import StreamSlice, StreamState
+from airbyte_cdk.sources.types import StreamSlice
 
 
 @dataclass
diff --git a/airbyte_cdk/sources/message/concurrent_repository.py b/airbyte_cdk/sources/message/concurrent_repository.py
new file mode 100644
index 000000000..947ee4c46
--- /dev/null
+++ b/airbyte_cdk/sources/message/concurrent_repository.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+
+from queue import Queue
+from typing import Callable, Iterable
+
+from airbyte_cdk.models import AirbyteMessage, Level
+from airbyte_cdk.sources.message.repository import LogMessage, MessageRepository
+from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem
+
+
+class ConcurrentMessageRepository(MessageRepository):
+    """
+    Message repository that immediately loads messages onto the queue processed on the
+    main thread. This ensures that messages are processed in the order they are
+    received. The InMemoryMessageRepository implementation does not have guaranteed
+    ordering since whether to process the main thread vs. partitions is non-deterministic
+    and there can be a lag between reading on the main thread and consuming messages
+    from the MessageRepository.
+
+    This is particularly important for the connector builder which relies on grouping
+    of messages to organize request/response, pages, and partitions.
+    """
+
+    def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepository):
+        self._queue = queue
+        self._decorated_message_repository = message_repository
+
+    def emit_message(self, message: AirbyteMessage) -> None:
+        self._decorated_message_repository.emit_message(message)
+        for message in self._decorated_message_repository.consume_queue():
+            self._queue.put(message)
+
+    def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None:
+        self._decorated_message_repository.log_message(level, message_provider)
+        for message in self._decorated_message_repository.consume_queue():
+            self._queue.put(message)
+
+    def consume_queue(self) -> Iterable[AirbyteMessage]:
+        """
+        This method shouldn't need to be called because emit_message() already loads
+        messages onto the queue processed on the main thread.
+        """
+        yield from []
diff --git a/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte_cdk/sources/streams/concurrent/partition_reader.py
index 3d23fd9cf..0edc5056a 100644
--- a/airbyte_cdk/sources/streams/concurrent/partition_reader.py
+++ b/airbyte_cdk/sources/streams/concurrent/partition_reader.py
@@ -1,14 +1,45 @@
-#
-# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
-#
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+
+import logging
 from queue import Queue
+from typing import Optional
 
 from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
+from airbyte_cdk.sources.message.repository import MessageRepository
+from airbyte_cdk.sources.streams.concurrent.cursor import Cursor
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.types import (
     PartitionCompleteSentinel,
     QueueItem,
 )
+from airbyte_cdk.sources.utils.slice_logger import SliceLogger
+
+
+# Since moving the connector builder workflow onto the concurrent CDK, which requires correctly ordered
+# grouping of log messages onto the main write thread using the ConcurrentMessageRepository, this
+# separate flow and class that was used to log slices onto this partition's message_repository
+# should just be replaced by emitting messages directly onto the repository instead of through an intermediary.
+class PartitionLogger:
+    """
+    Helper class that provides a mechanism for passing a log message onto the current
+    partition's message repository
+    """
+
+    def __init__(
+        self,
+        slice_logger: SliceLogger,
+        logger: logging.Logger,
+        message_repository: MessageRepository,
+    ):
+        self._slice_logger = slice_logger
+        self._logger = logger
+        self._message_repository = message_repository
+
+    def log(self, partition: Partition) -> None:
+        if self._slice_logger.should_log_slice_message(self._logger):
+            self._message_repository.emit_message(
+                self._slice_logger.create_slice_log_message(partition.to_slice())
+            )
 
 
 class PartitionReader:
@@ -18,13 +49,18 @@ class PartitionReader:
 
     _IS_SUCCESSFUL = True
 
-    def __init__(self, queue: Queue[QueueItem]) -> None:
+    def __init__(
+        self,
+        queue: Queue[QueueItem],
+        partition_logger: Optional[PartitionLogger] = None,
+    ) -> None:
        """
         :param queue: The queue to put the records in.
         """
         self._queue = queue
+        self._partition_logger = partition_logger
 
-    def process_partition(self, partition: Partition) -> None:
+    def process_partition(self, partition: Partition, cursor: Cursor) -> None:
         """
         Process a partition and put the records in the output queue.
         When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated. 
@@ -37,8 +73,13 @@ def process_partition(self, partition: Partition) -> None:
 
         :return: None
         """
         try:
+            if self._partition_logger:
+                self._partition_logger.log(partition)
+
             for record in partition.read():
                 self._queue.put(record)
+                cursor.observe(record)
+            cursor.close_partition(partition)
             self._queue.put(PartitionCompleteSentinel(partition, self._IS_SUCCESSFUL))
         except Exception as e:
             self._queue.put(StreamThreadException(e, partition.stream_name()))
diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
index 77644c6b9..3ae63c242 100644
--- a/airbyte_cdk/sources/streams/concurrent/partitions/types.py
+++ b/airbyte_cdk/sources/streams/concurrent/partitions/types.py
@@ -4,6 +4,7 @@
 
 from typing import Any, Union
 
+from airbyte_cdk.models import AirbyteMessage
 from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import (
     PartitionGenerationCompletedSentinel,
 )
@@ -34,5 +35,10 @@ def __eq__(self, other: Any) -> bool:
 Typedef representing the items that can be added to the ThreadBasedConcurrentStream
 """
 QueueItem = Union[
-    Record, Partition, PartitionCompleteSentinel, PartitionGenerationCompletedSentinel, Exception
+    Record,
+    Partition,
+    PartitionCompleteSentinel,
+    PartitionGenerationCompletedSentinel,
+    Exception,
+    AirbyteMessage,
 ]
diff --git a/airbyte_cdk/sources/utils/slice_logger.py b/airbyte_cdk/sources/utils/slice_logger.py
index ee802a7a6..4b29f3e0d 100644
--- a/airbyte_cdk/sources/utils/slice_logger.py
+++ b/airbyte_cdk/sources/utils/slice_logger.py
@@ -11,6 +11,10 @@
 from airbyte_cdk.models import Type as MessageType
 
 
+# Once everything runs on the concurrent CDK and we've cleaned up the legacy flows, we should try to remove
+# this class and write messages directly to the message_repository instead of through the logger, because for
+# cases like the connector builder, where ordering of messages is important, using the logger can cause
+# messages to be grouped out of order. Alas, work for a different day.
 class SliceLogger(ABC):
     """
     SliceLogger is an interface that allows us to log slices of data in a uniform way. 
diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 2587fb95a..993b42e8e 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -17,9 +17,6 @@ from airbyte_cdk import connector_builder from airbyte_cdk.connector_builder.connector_builder_handler import ( - DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE, - DEFAULT_MAXIMUM_NUMBER_OF_SLICES, - DEFAULT_MAXIMUM_RECORDS, TestLimits, create_source, get_limits, @@ -56,8 +53,11 @@ Type, ) from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( + ConcurrentDeclarativeSource, + TestLimits, +) from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream -from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever from airbyte_cdk.sources.declarative.stream_slicers import StreamSlicerTestReadDecorator from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -530,7 +530,9 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): config = copy.deepcopy(RESOLVE_MANIFEST_CONFIG) command = "resolve_manifest" config["__command"] = command - source = ManifestDeclarativeSource(source_config=MANIFEST) + source = ConcurrentDeclarativeSource( + catalog=None, config=config, state=None, source_config=MANIFEST + ) limits = TestLimits() resolved_manifest = handle_connector_builder_request( source, command, config, create_configured_catalog("dummy_stream"), _A_STATE, limits @@ -679,19 +681,21 @@ def test_resolve_manifest(valid_resolve_manifest_config_file): def test_resolve_manifest_error_returns_error_response(): - class MockManifestDeclarativeSource: + class MockConcurrentDeclarativeSource: @property def resolved_manifest(self): raise ValueError - source = MockManifestDeclarativeSource() + source = MockConcurrentDeclarativeSource() response = resolve_manifest(source) assert "Error resolving manifest" in response.trace.error.message def test_read(): config = TEST_READ_CONFIG - source = ManifestDeclarativeSource(source_config=MANIFEST) + source = ConcurrentDeclarativeSource( + catalog=None, config=config, state=None, source_config=MANIFEST + ) real_record = AirbyteRecordMessage( data={"id": "1234", "key": "value"}, emitted_at=1, stream=_stream_name @@ -780,7 +784,9 @@ def test_config_update() -> None: "client_secret": "a client secret", "refresh_token": "a refresh token", } - source = ManifestDeclarativeSource(source_config=manifest) + source = ConcurrentDeclarativeSource( + catalog=None, config=config, state=None, source_config=manifest + ) refresh_request_response = { "access_token": "an updated access token", @@ -817,7 +823,7 @@ def cursor_field(self): def name(self): return _stream_name - class MockManifestDeclarativeSource: + class MockConcurrentDeclarativeSource: def streams(self, config): return [MockDeclarativeStream()] @@ -839,7 +845,7 @@ def check_config_against_spec(self) -> Literal[False]: stack_trace = "a stack trace" mock_from_exception.return_value = stack_trace - source = MockManifestDeclarativeSource() + source = MockConcurrentDeclarativeSource() limits = TestLimits() response = read_stream( source, @@ -881,19 +887,22 @@ def test_handle_429_response(): config = TEST_READ_CONFIG limits = TestLimits() - source = 
create_source(config, limits)
+    catalog = ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG)
+    source = create_source(config=config, limits=limits, catalog=catalog, state=None)
 
     with patch("requests.Session.send", return_value=response) as mock_send:
         response = handle_connector_builder_request(
             source,
             "test_read",
             config,
-            ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG),
+            catalog,
             _A_PER_PARTITION_STATE,
             limits,
         )
 
-        mock_send.assert_called_once()
+        # The test read will attempt to read 5 partitions, issuing 1 request
+        # for each partition; these requests are not retried
+        assert mock_send.call_count == 5
 
 
 @pytest.mark.parametrize(
@@ -945,7 +954,7 @@ def test_invalid_config_command(invalid_config_file, dummy_catalog):
 
 @pytest.fixture
 def manifest_declarative_source():
-    return mock.Mock(spec=ManifestDeclarativeSource, autospec=True)
+    return mock.Mock(spec=ConcurrentDeclarativeSource, autospec=True)
 
 
 def create_mock_retriever(name, url_base, path):
@@ -970,16 +979,16 @@ def create_mock_declarative_stream(http_stream):
         (
             "test_no_test_read_config",
             {},
-            DEFAULT_MAXIMUM_RECORDS,
-            DEFAULT_MAXIMUM_NUMBER_OF_SLICES,
-            DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE,
+            TestLimits.DEFAULT_MAX_RECORDS,
+            TestLimits.DEFAULT_MAX_SLICES,
+            TestLimits.DEFAULT_MAX_PAGES_PER_SLICE,
         ),
         (
             "test_no_values_set",
             {"__test_read_config": {}},
-            DEFAULT_MAXIMUM_RECORDS,
-            DEFAULT_MAXIMUM_NUMBER_OF_SLICES,
-            DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE,
+            TestLimits.DEFAULT_MAX_RECORDS,
+            TestLimits.DEFAULT_MAX_SLICES,
+            TestLimits.DEFAULT_MAX_PAGES_PER_SLICE,
         ),
         (
             "test_values_are_set",
@@ -1007,9 +1016,9 @@ def test_create_source():
 
     config = {"__injected_declarative_manifest": MANIFEST}
 
-    source = create_source(config, limits)
+    source = create_source(config=config, limits=limits, catalog=None, state=None)
 
-    assert isinstance(source, ManifestDeclarativeSource)
+    assert isinstance(source, ConcurrentDeclarativeSource)
     assert source._constructor._limit_pages_fetched_per_slice == limits.max_pages_per_slice
     assert source._constructor._limit_slices_fetched == limits.max_slices
     assert source._constructor._disable_cache
@@ -1101,7 +1110,7 @@ def test_read_source(mock_http_stream):
 
     config = {"__injected_declarative_manifest": MANIFEST}
 
-    source = create_source(config, limits)
+    source = create_source(config=config, limits=limits, catalog=catalog, state=None)
 
     output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data
     slices = output_data["slices"]
@@ -1149,7 +1158,7 @@ def test_read_source_single_page_single_slice(mock_http_stream):
 
     config = {"__injected_declarative_manifest": MANIFEST}
 
-    source = create_source(config, limits)
+    source = create_source(config=config, limits=limits, catalog=catalog, state=None)
 
     output_data = read_stream(source, config, catalog, _A_PER_PARTITION_STATE, limits).record.data
     slices = output_data["slices"]
@@ -1236,7 +1245,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
         test_manifest["streams"][0]["$parameters"]["url_base"] = url_base
         config = {"__injected_declarative_manifest": test_manifest}
 
-        source = create_source(config, limits)
+        source = create_source(config=config, limits=limits, catalog=catalog, state=None)
 
         with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False):
             output_data = read_stream(
@@ -1266,13 +1275,13 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error
         pytest.param(
             "CLOUD",
             "https://10.0.27.27/tokens/bearer",
-            "AirbyteTracedException",
+            
"StreamThreadException", id="test_cloud_read_with_private_endpoint", ), pytest.param( "CLOUD", "http://unsecured.protocol/tokens/bearer", - "InvalidSchema", + "StreamThreadException", id="test_cloud_read_with_unsecured_endpoint", ), pytest.param( @@ -1332,7 +1341,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected ) config = {"__injected_declarative_manifest": test_manifest} - source = create_source(config, limits) + source = create_source(config=config, limits=limits, catalog=catalog, state=None) with mock.patch.dict(os.environ, {"DEPLOYMENT_MODE": deployment_mode}, clear=False): output_data = read_stream( @@ -1389,7 +1398,9 @@ def test_read_stream_exception_with_secrets(): def test_full_resolve_manifest(valid_resolve_manifest_config_file): config = copy.deepcopy(RESOLVE_DYNAMIC_STREAM_MANIFEST_CONFIG) command = config["__command"] - source = ManifestDeclarativeSource(source_config=DYNAMIC_STREAM_MANIFEST) + source = ConcurrentDeclarativeSource( + catalog=None, config=config, state=None, source_config=DYNAMIC_STREAM_MANIFEST + ) limits = TestLimits(max_streams=2) with HttpMocker() as http_mocker: http_mocker.get( diff --git a/unit_tests/connector_builder/test_message_grouper.py b/unit_tests/connector_builder/test_message_grouper.py index 6c4f11526..e79ee117c 100644 --- a/unit_tests/connector_builder/test_message_grouper.py +++ b/unit_tests/connector_builder/test_message_grouper.py @@ -307,126 +307,6 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: assert actual_log == expected_logs[i] -@pytest.mark.parametrize( - "request_record_limit, max_record_limit, should_fail", - [ - pytest.param(1, 3, False, id="test_create_request_with_record_limit"), - pytest.param(3, 1, True, id="test_create_request_record_limit_exceeds_max"), - ], -) -@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") -def test_get_grouped_messages_record_limit( - mock_entrypoint_read: Mock, request_record_limit: int, max_record_limit: int, should_fail: bool -) -> None: - stream_name = "hashiras" - url = "https://demonslayers.com/api/v1/hashiras?era=taisho" - request = { - "headers": {"Content-Type": "application/json"}, - "method": "GET", - "body": {"content": '{"custom": "field"}'}, - } - response = { - "status_code": 200, - "headers": {"field": "value"}, - "body": {"content": '{"name": "field"}'}, - } - mock_source = make_mock_source( - mock_entrypoint_read, - iter( - [ - request_response_log_message(request, response, url, stream_name), - record_message(stream_name, {"name": "Shinobu Kocho"}), - record_message(stream_name, {"name": "Muichiro Tokito"}), - request_response_log_message(request, response, url, stream_name), - record_message(stream_name, {"name": "Mitsuri Kanroji"}), - ] - ), - ) - n_records = 2 - record_limit = min(request_record_limit, max_record_limit) - - api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) - # this is the call we expect to raise an exception - if should_fail: - with pytest.raises(ValueError): - api.run_test_read( - mock_source, - config=CONFIG, - configured_catalog=create_configured_catalog(stream_name), - stream_name=stream_name, - state=_NO_STATE, - record_limit=request_record_limit, - ) - else: - actual_response: StreamRead = api.run_test_read( - mock_source, - config=CONFIG, - configured_catalog=create_configured_catalog(stream_name), - stream_name=stream_name, - state=_NO_STATE, - record_limit=request_record_limit, - ) - single_slice = 
actual_response.slices[0] - total_records = 0 - for i, actual_page in enumerate(single_slice.pages): - total_records += len(actual_page.records) - assert total_records == min([record_limit, n_records]) - - assert (total_records >= max_record_limit) == actual_response.test_read_limit_reached - - -@pytest.mark.parametrize( - "max_record_limit", - [ - pytest.param(2, id="test_create_request_no_record_limit"), - pytest.param(1, id="test_create_request_no_record_limit_n_records_exceed_max"), - ], -) -@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") -def test_get_grouped_messages_default_record_limit( - mock_entrypoint_read: Mock, max_record_limit: int -) -> None: - stream_name = "hashiras" - url = "https://demonslayers.com/api/v1/hashiras?era=taisho" - request = { - "headers": {"Content-Type": "application/json"}, - "method": "GET", - "body": {"content": '{"custom": "field"}'}, - } - response = { - "status_code": 200, - "headers": {"field": "value"}, - "body": {"content": '{"name": "field"}'}, - } - mock_source = make_mock_source( - mock_entrypoint_read, - iter( - [ - request_response_log_message(request, response, url, stream_name), - record_message(stream_name, {"name": "Shinobu Kocho"}), - record_message(stream_name, {"name": "Muichiro Tokito"}), - request_response_log_message(request, response, url, stream_name), - record_message(stream_name, {"name": "Mitsuri Kanroji"}), - ] - ), - ) - n_records = 2 - - api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) - actual_response: StreamRead = api.run_test_read( - source=mock_source, - config=CONFIG, - configured_catalog=create_configured_catalog(stream_name), - stream_name=stream_name, - state=_NO_STATE, - ) - single_slice = actual_response.slices[0] - total_records = 0 - for i, actual_page in enumerate(single_slice.pages): - total_records += len(actual_page.records) - assert total_records == min([max_record_limit, n_records]) - - @patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: stream_name = "hashiras" diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py index 13d1194dd..8637089a9 100644 --- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py +++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py @@ -3614,7 +3614,14 @@ def test_given_no_partitions_processed_when_close_partition_then_no_state_update slices = list(cursor.stream_slices()) # Call once for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=slice, + ) ) assert cursor.state == { @@ -3692,7 +3699,14 @@ def test_given_unfinished_first_parent_partition_no_parent_state_update(): # Close all partitions except from the first one for slice in slices[1:]: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3780,7 +3794,14 @@ def 
test_given_unfinished_last_parent_partition_with_partial_parent_state_update # Close all partitions except from the first one for slice in slices[:-1]: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3863,7 +3884,14 @@ def test_given_all_partitions_finished_when_close_partition_then_final_state_emi slices = list(cursor.stream_slices()) for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -3930,7 +3958,14 @@ def test_given_partition_limit_exceeded_when_close_partition_then_switch_to_glob slices = list(cursor.stream_slices()) for slice in slices: cursor.close_partition( - DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), slice) + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=slice, + ) ) cursor.ensure_at_least_one_state_emitted() @@ -4007,7 +4042,16 @@ def test_semaphore_cleanup(): # Close partitions to acquire semaphores (value back to 0) for s in generated_slices: - cursor.close_partition(DeclarativePartition("test_stream", {}, MagicMock(), MagicMock(), s)) + cursor.close_partition( + DeclarativePartition( + stream_name="test_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=s, + ) + ) # Check state after closing partitions assert len(cursor._partitions_done_generating_stream_slices) == 0 @@ -4119,15 +4163,38 @@ def test_duplicate_partition_after_closing_partition_cursor_deleted(): first_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=first_1, + ) ) two = next(slice_gen) - cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two)) + cursor.close_partition( + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=two, + ) + ) second_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=second_1, + ) ) assert cursor._IS_PARTITION_DUPLICATION_LOGGED is False # No duplicate detected @@ -4181,16 +4248,39 @@ def test_duplicate_partition_after_closing_partition_cursor_exists(): first_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), first_1) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=first_1, + ) ) two = 
next(slice_gen) - cursor.close_partition(DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), two)) + cursor.close_partition( + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=two, + ) + ) # Second “1” should appear because the semaphore was cleaned up second_1 = next(slice_gen) cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), second_1) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=second_1, + ) ) with pytest.raises(StopIteration): @@ -4241,11 +4331,25 @@ def test_duplicate_partition_while_processing(): # Close “2” first cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[1]) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=generated[1], + ) ) # Now close the initial “1” cursor.close_partition( - DeclarativePartition("dup_stream", {}, MagicMock(), MagicMock(), generated[0]) + DeclarativePartition( + stream_name="dup_stream", + json_schema={}, + retriever=MagicMock(), + message_repository=MagicMock(), + max_records_limit=None, + stream_slice=generated[0], + ) ) assert cursor._IS_PARTITION_DUPLICATION_LOGGED is True # warning emitted diff --git a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py index a1e390177..06b46bfcc 100644 --- a/unit_tests/sources/declarative/retrievers/test_simple_retriever.py +++ b/unit_tests/sources/declarative/retrievers/test_simple_retriever.py @@ -11,7 +11,14 @@ import requests from airbyte_cdk import YamlDeclarativeSource -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + AirbyteRecordMessage, + Level, + SyncMode, + Type, +) from airbyte_cdk.sources.declarative.auth.declarative_authenticator import NoAuth from airbyte_cdk.sources.declarative.decoders import JsonDecoder from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordSelector diff --git a/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py b/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py index 97f89879c..20147465f 100644 --- a/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py +++ b/unit_tests/sources/declarative/schema/test_dynamic_schema_loader.py @@ -10,9 +10,7 @@ from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, -) -from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( - ModelToComponentFactory, + TestLimits, ) from airbyte_cdk.sources.declarative.schema import DynamicSchemaLoader, SchemaTypeIdentifier from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -353,14 +351,13 @@ def test_dynamic_schema_loader_with_type_conditions(): }, }, } + source = ConcurrentDeclarativeSource( source_config=_MANIFEST_WITH_TYPE_CONDITIONS, config=_CONFIG, catalog=None, state=None, - component_factory=ModelToComponentFactory( - disable_cache=True - ), # Avoid caching on the HttpClient which could result in caching the requests/responses of other tests + limits=TestLimits(), # Avoid 
caching on the HttpClient which could result in caching the requests/responses of other tests ) with HttpMocker() as http_mocker: http_mocker.get( diff --git a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py index 3ced03a69..9bab0f56f 100644 --- a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py +++ b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py @@ -4,14 +4,13 @@ from unittest import TestCase from unittest.mock import Mock +# This allows for the global total_record_counter to be reset between tests +import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.declarative.retrievers import Retriever -from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( - DeclarativePartitionFactory, -) from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.core import StreamData -from airbyte_cdk.sources.types import StreamSlice +from airbyte_cdk.sources.types import Record, StreamSlice _STREAM_NAME = "a_stream_name" _JSON_SCHEMA = {"type": "object", "properties": {}} @@ -32,7 +31,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase): def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None: retriever = self._mock_retriever([]) message_repository = Mock(spec=MessageRepository) - partition_factory = DeclarativePartitionFactory( + partition_factory = declarative_partition_generator.DeclarativePartitionFactory( _STREAM_NAME, _JSON_SCHEMA, retriever, @@ -47,7 +46,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) def test_given_a_mapping_when_read_then_yield_record(self) -> None: retriever = self._mock_retriever([_A_RECORD]) message_repository = Mock(spec=MessageRepository) - partition_factory = DeclarativePartitionFactory( + partition_factory = declarative_partition_generator.DeclarativePartitionFactory( _STREAM_NAME, _JSON_SCHEMA, retriever, @@ -65,7 +64,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None: def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None: retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE]) message_repository = Mock(spec=MessageRepository) - partition_factory = DeclarativePartitionFactory( + partition_factory = declarative_partition_generator.DeclarativePartitionFactory( _STREAM_NAME, _JSON_SCHEMA, retriever, @@ -76,6 +75,78 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> N message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE) + def test_max_records_reached_stops_reading(self) -> None: + declarative_partition_generator.total_record_counter = 0 + + expected_records = [ + Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Charles"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Alex"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Yuki"}, stream_name="stream_name"), + ] + + mock_records = expected_records + [ + Record(data={"id": 1, "name": "Lewis"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Lando"}, 
stream_name="stream_name"), + ] + + retriever = self._mock_retriever(mock_records) + message_repository = Mock(spec=MessageRepository) + partition_factory = declarative_partition_generator.DeclarativePartitionFactory( + _STREAM_NAME, + _JSON_SCHEMA, + retriever, + message_repository, + max_records_limit=5, + ) + + partition = partition_factory.create(_A_STREAM_SLICE) + + actual_records = list(partition.read()) + + assert len(actual_records) == 5 + assert actual_records == expected_records + + def test_max_records_reached_on_previous_partition(self) -> None: + declarative_partition_generator.total_record_counter = 0 + + expected_records = [ + Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Charles"}, stream_name="stream_name"), + ] + + mock_records = expected_records + [ + Record(data={"id": 1, "name": "Alex"}, stream_name="stream_name"), + Record(data={"id": 1, "name": "Yuki"}, stream_name="stream_name"), + ] + + retriever = self._mock_retriever(mock_records) + message_repository = Mock(spec=MessageRepository) + partition_factory = declarative_partition_generator.DeclarativePartitionFactory( + _STREAM_NAME, + _JSON_SCHEMA, + retriever, + message_repository, + max_records_limit=3, + ) + + partition = partition_factory.create(_A_STREAM_SLICE) + + first_partition_records = list(partition.read()) + + assert len(first_partition_records) == 3 + assert first_partition_records == expected_records + + second_partition_records = list(partition.read()) + assert len(second_partition_records) == 0 + + # The DeclarativePartition exits out of the read before attempting to read_records() if + # the max_records_limit has already been reached. So we only expect to see read_records() + # called for the first partition read and not the second + retriever.read_records.assert_called_once() + @staticmethod def _mock_retriever(read_return_value: List[StreamData]) -> Mock: retriever = Mock(spec=Retriever) diff --git a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 50695ba1e..75b52f6b2 100644 --- a/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -50,7 +50,10 @@ def __init__( self._message_repository = InMemoryMessageRepository() threadpool_manager = ThreadPoolManager(threadpool, streams[0].logger) concurrent_source = ConcurrentSource( - threadpool_manager, streams[0].logger, NeverLogSliceLogger(), self._message_repository + threadpool=threadpool_manager, + logger=streams[0].logger, + slice_logger=NeverLogSliceLogger(), + message_repository=self._message_repository, ) super().__init__(concurrent_source) self._streams = streams diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index d6ea64583..a681f75eb 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -176,10 +176,12 @@ def test_handle_partition(self): self._partition_reader, ) + expected_cursor = handler._stream_name_to_instance[_ANOTHER_STREAM_NAME].cursor + handler.on_partition(self._a_closed_partition) self._thread_pool_manager.submit.assert_called_with( - self._partition_reader.process_partition, 
self._a_closed_partition + self._partition_reader.process_partition, self._a_closed_partition, expected_cursor ) assert ( self._a_closed_partition in handler._streams_to_running_partitions[_ANOTHER_STREAM_NAME] @@ -201,10 +203,12 @@ def test_handle_partition_emits_log_message_if_it_should_be_logged(self): self._partition_reader, ) + expected_cursor = handler._stream_name_to_instance[_STREAM_NAME].cursor + handler.on_partition(self._an_open_partition) self._thread_pool_manager.submit.assert_called_with( - self._partition_reader.process_partition, self._an_open_partition + self._partition_reader.process_partition, self._an_open_partition, expected_cursor ) self._message_repository.emit_message.assert_called_with(self._log_message) @@ -253,8 +257,6 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel ] assert messages == expected_messages - self._stream.cursor.close_partition.assert_called_once() - @freezegun.freeze_time("2020-01-01T00:00:00") def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done( self, @@ -302,55 +304,6 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre ) ] assert messages == expected_messages - self._another_stream.cursor.close_partition.assert_called_once() - - @freezegun.freeze_time("2020-01-01T00:00:00") - def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete( - self, - ) -> None: - self._a_closed_partition.stream_name.return_value = self._stream.name - self._stream.cursor.close_partition.side_effect = ValueError - - handler = ConcurrentReadProcessor( - [self._stream], - self._partition_enqueuer, - self._thread_pool_manager, - self._logger, - self._slice_logger, - self._message_repository, - self._partition_reader, - ) - handler.start_next_partition_generator() - handler.on_partition(self._a_closed_partition) - list( - handler.on_partition_generation_completed( - PartitionGenerationCompletedSentinel(self._stream) - ) - ) - messages = list( - handler.on_partition_complete_sentinel( - PartitionCompleteSentinel(self._a_closed_partition) - ) - ) - - expected_status_message = AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor( - name=self._stream.name, - ), - status=AirbyteStreamStatus.INCOMPLETE, - ), - emitted_at=1577836800000.0, - ), - ) - assert list(map(lambda message: message.trace.type, messages)) == [ - TraceType.ERROR, - TraceType.STREAM_STATUS, - ] - assert messages[1] == expected_status_message @freezegun.freeze_time("2020-01-01T00:00:00") def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_stream_is_not_done( @@ -379,7 +332,6 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s expected_messages = [] assert messages == expected_messages - self._stream.cursor.close_partition.assert_called_once() @freezegun.freeze_time("2020-01-01T00:00:00") def test_on_record_no_status_message_no_repository_messge(self): diff --git a/unit_tests/sources/streams/concurrent/test_partition_reader.py b/unit_tests/sources/streams/concurrent/test_partition_reader.py index 1910e034d..a41750772 100644 --- a/unit_tests/sources/streams/concurrent/test_partition_reader.py +++ b/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -1,6 +1,5 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
-# +# Copyright (c) 2025 Airbyte, Inc., all rights reserved. + import unittest from queue import Queue from typing import Callable, Iterable, List @@ -8,7 +7,9 @@ import pytest +from airbyte_cdk import InMemoryMessageRepository from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException +from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import ( @@ -26,10 +27,15 @@ class PartitionReaderTest(unittest.TestCase): def setUp(self) -> None: self._queue: Queue[QueueItem] = Queue() - self._partition_reader = PartitionReader(self._queue) + self._partition_reader = PartitionReader(self._queue, None) def test_given_no_records_when_process_partition_then_only_emit_sentinel(self): - self._partition_reader.process_partition(self._a_partition([])) + cursor = FinalStateCursor( + stream_name="test", + stream_namespace=None, + message_repository=InMemoryMessageRepository(), + ) + self._partition_reader.process_partition(self._a_partition([]), cursor) while queue_item := self._queue.get(): if not isinstance(queue_item, PartitionCompleteSentinel): @@ -40,19 +46,24 @@ def test_given_read_partition_successful_when_process_partition_then_queue_recor self, ): partition = self._a_partition(_RECORDS) - self._partition_reader.process_partition(partition) + cursor = Mock() + self._partition_reader.process_partition(partition, cursor) queue_content = self._consume_queue() assert queue_content == _RECORDS + [PartitionCompleteSentinel(partition)] - def test_given_exception_when_process_partition_then_queue_records_and_exception_and_sentinel( + cursor.observe.assert_called() + cursor.close_partition.assert_called_once() + + def test_given_exception_from_read_when_process_partition_then_queue_records_and_exception_and_sentinel( self, ): partition = Mock() + cursor = Mock() exception = ValueError() partition.read.side_effect = self._read_with_exception(_RECORDS, exception) - self._partition_reader.process_partition(partition) + self._partition_reader.process_partition(partition, cursor) queue_content = self._consume_queue() @@ -61,6 +72,23 @@ def test_given_exception_when_process_partition_then_queue_records_and_exception PartitionCompleteSentinel(partition), ] + def test_given_exception_from_close_slice_when_process_partition_then_queue_records_and_exception_and_sentinel( + self, + ): + partition = self._a_partition(_RECORDS) + cursor = Mock() + exception = ValueError() + cursor.close_partition.side_effect = self._close_partition_with_exception(exception) + self._partition_reader.process_partition(partition, cursor) + + queue_content = self._consume_queue() + + # 4 total messages in queue. 
2 records, 1 thread exception, 1 partition sentinel value + assert len(queue_content) == 4 + assert queue_content[:2] == _RECORDS + assert isinstance(queue_content[2], StreamThreadException) + assert queue_content[3] == PartitionCompleteSentinel(partition) + def _a_partition(self, records: List[Record]) -> Partition: partition = Mock(spec=Partition) partition.read.return_value = iter(records) @@ -76,6 +104,13 @@ def mocked_function() -> Iterable[Record]: return mocked_function + @staticmethod + def _close_partition_with_exception(exception: Exception) -> Callable[[Partition], None]: + def mocked_function(partition: Partition) -> None: + raise exception + + return mocked_function + def _consume_queue(self): queue_content = [] while queue_item := self._queue.get(): From f73324d4bbc3f0aff7a25b3209879496848c1462 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Tue, 19 Aug 2025 12:54:20 -0700 Subject: [PATCH 02/17] add pytest env to cache name to enforce unique cache per test file --- airbyte_cdk/sources/streams/http/http_client.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index c4fa86866..e10c95890 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -127,6 +127,12 @@ def cache_filename(self) -> str: Override if needed. Return the name of cache file Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. """ + # This is a hack so that we ensure that the same cache is not used across different test files + # because we observed some flakiness in tests when running on CI + # https://github.com/airbytehq/airbyte-python-cdk/pull/688 + # https://github.com/airbytehq/airbyte-python-cdk/pull/712 + if os.getenv("PYTEST_CURRENT_TEST"): + return f"{self._name}-{os.getenv('PYTEST_CURRENT_TEST')}.sqlite" return f"{self._name}.sqlite" def _request_session(self) -> requests.Session: @@ -153,7 +159,10 @@ def _request_session(self) -> requests.Session: # * `If the application running SQLite crashes, the data will be safe, but the database [might become corrupted](https://www.sqlite.org/howtocorrupt.html#cfgerr) if the operating system crashes or the computer loses power before that data has been written to the disk surface.` in [this description](https://www.sqlite.org/pragma.html#pragma_synchronous). 
backend = requests_cache.SQLiteCache(sqlite_path, fast_save=True, wal=True) return CachedLimiterSession( - sqlite_path, backend=backend, api_budget=self._api_budget, match_headers=True + cache_name=sqlite_path, + backend=backend, + api_budget=self._api_budget, + match_headers=True, ) else: return LimiterSession(api_budget=self._api_budget) From 8a971f3b795e28bf03f420abc4f3aea62db97a00 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Tue, 19 Aug 2025 15:17:37 -0700 Subject: [PATCH 03/17] adding lots of logging to diagnose flaky test --- .../concurrent_read_processor.py | 4 +++ .../sources/message/concurrent_repository.py | 30 ++++++++++++++++++- .../test_concurrent_declarative_source.py | 2 ++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 33731e74c..1bd6d4b46 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -2,6 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # import logging +import os from typing import Dict, Iterable, List, Optional, Set from airbyte_cdk.exception_handler import generate_failed_streams_error_message @@ -154,6 +155,9 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: ) self._record_counter[stream.name] += 1 stream.cursor.observe(record) + test_env = os.getenv("PYTEST_CURRENT_TEST") + if test_env and "test_concurrent_declarative_source.py" in test_env: + self._logger.info(f"Processing and emitting: {message.__dict__}") yield message yield from self._message_repository.consume_queue() diff --git a/airbyte_cdk/sources/message/concurrent_repository.py b/airbyte_cdk/sources/message/concurrent_repository.py index 947ee4c46..2b683a49f 100644 --- a/airbyte_cdk/sources/message/concurrent_repository.py +++ b/airbyte_cdk/sources/message/concurrent_repository.py @@ -1,12 +1,16 @@ # Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
- +import logging +import os from queue import Queue from typing import Callable, Iterable from airbyte_cdk.models import AirbyteMessage, Level +from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.message.repository import LogMessage, MessageRepository from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem +logger = logging.getLogger("airbyte") + class ConcurrentMessageRepository(MessageRepository): """ @@ -25,14 +29,25 @@ def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepositor self._queue = queue self._decorated_message_repository = message_repository + test_env = os.getenv("PYTEST_CURRENT_TEST") + self._log_messages_for_testing = ( + test_env and "test_concurrent_declarative_source.py" in test_env + ) + def emit_message(self, message: AirbyteMessage) -> None: + if self._log_messages_for_testing: + self._log_message(message) self._decorated_message_repository.emit_message(message) for message in self._decorated_message_repository.consume_queue(): + if self._log_messages_for_testing: + self._log_message(message) self._queue.put(message) def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: self._decorated_message_repository.log_message(level, message_provider) for message in self._decorated_message_repository.consume_queue(): + if self._log_messages_for_testing: + self._log_message(message) self._queue.put(message) def consume_queue(self) -> Iterable[AirbyteMessage]: @@ -41,3 +56,16 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: loading messages onto the queue processed on the main thread. """ yield from [] + + @staticmethod + def _log_message(message: AirbyteMessage) -> None: + if message.type == MessageType.STATE: + if message.state and message.state.stream: + state = message.state.stream.stream_state.__dict__ + logger.info( + f"Processing and emitting message of type {message.type} with contents: {message.state.stream.stream_state.__dict__}" + ) + else: + logger.info( + f"Processing and emitting message of type {message.type} with contents: {message.__dict__}" + ) diff --git a/unit_tests/sources/declarative/test_concurrent_declarative_source.py b/unit_tests/sources/declarative/test_concurrent_declarative_source.py index bdfefbd80..4279e53c6 100644 --- a/unit_tests/sources/declarative/test_concurrent_declarative_source.py +++ b/unit_tests/sources/declarative/test_concurrent_declarative_source.py @@ -1507,6 +1507,8 @@ def test_read_concurrent_with_failing_partition_in_the_middle(): ): messages.append(message) except AirbyteTracedException: + locations_states = get_states_for_stream(stream_name="locations", messages=messages) + assert len(locations_states) == 3 assert ( get_states_for_stream(stream_name="locations", messages=messages)[ -1 From 39ba730eb1f55711b36f157cd043f38ba4307fdd Mon Sep 17 00:00:00 2001 From: brianjlai Date: Tue, 19 Aug 2025 15:39:44 -0700 Subject: [PATCH 04/17] more detailed logging --- .../sources/message/concurrent_repository.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/airbyte_cdk/sources/message/concurrent_repository.py b/airbyte_cdk/sources/message/concurrent_repository.py index 2b683a49f..1ed07b9ff 100644 --- a/airbyte_cdk/sources/message/concurrent_repository.py +++ b/airbyte_cdk/sources/message/concurrent_repository.py @@ -35,19 +35,17 @@ def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepositor ) def emit_message(self, message: AirbyteMessage) -> None: - if 
self._log_messages_for_testing: - self._log_message(message) self._decorated_message_repository.emit_message(message) for message in self._decorated_message_repository.consume_queue(): if self._log_messages_for_testing: - self._log_message(message) + self._log_message(message, "emit_message()") self._queue.put(message) def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: self._decorated_message_repository.log_message(level, message_provider) for message in self._decorated_message_repository.consume_queue(): if self._log_messages_for_testing: - self._log_message(message) + self._log_message(message, "log_message()") self._queue.put(message) def consume_queue(self) -> Iterable[AirbyteMessage]: @@ -58,14 +56,15 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: yield from [] @staticmethod - def _log_message(message: AirbyteMessage) -> None: + def _log_message(message: AirbyteMessage, calling_method: str) -> None: if message.type == MessageType.STATE: if message.state and message.state.stream: + stream_name = message.state.stream.stream_descriptor.name state = message.state.stream.stream_state.__dict__ logger.info( - f"Processing and emitting message of type {message.type} with contents: {message.state.stream.stream_state.__dict__}" + f"From {calling_method} -- emitting message of type {message.type} for stream {stream_name} with contents: {state}" ) else: logger.info( - f"Processing and emitting message of type {message.type} with contents: {message.__dict__}" + f"From {calling_method} -- emitting message of type {message.type} with contents: {message.__dict__}" ) From e7b2cd7ec1fef56b396af3761ea1aae75c8ac08b Mon Sep 17 00:00:00 2001 From: "maxime.c" Date: Tue, 19 Aug 2025 22:54:48 -0400 Subject: [PATCH 05/17] log every item in the queue --- .../sources/concurrent_source/concurrent_read_processor.py | 3 --- airbyte_cdk/sources/concurrent_source/concurrent_source.py | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 1bd6d4b46..bbc359a58 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -155,9 +155,6 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: ) self._record_counter[stream.name] += 1 stream.cursor.observe(record) - test_env = os.getenv("PYTEST_CURRENT_TEST") - if test_env and "test_concurrent_declarative_source.py" in test_env: - self._logger.info(f"Processing and emitting: {message.__dict__}") yield message yield from self._message_repository.consume_queue() diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index 9ccfc1088..855b9e20a 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -3,6 +3,7 @@ # import concurrent import logging +import os from queue import Queue from typing import Iterable, Iterator, List, Optional @@ -143,6 +144,9 @@ def _consume_from_queue( concurrent_stream_processor: ConcurrentReadProcessor, ) -> Iterable[AirbyteMessage]: while airbyte_message_or_record_or_exception := queue.get(): + test_env = os.getenv("PYTEST_CURRENT_TEST") + if test_env and "test_concurrent_declarative_source.py" in test_env: + self._logger.info(f"Processing and emitting: 
{airbyte_message_or_record_or_exception.__dict__}") yield from self._handle_item( airbyte_message_or_record_or_exception, concurrent_stream_processor, From 6d75a9262624e44a87b2b0f6a163485c0fda4a61 Mon Sep 17 00:00:00 2001 From: "maxime.c" Date: Tue, 19 Aug 2025 22:57:22 -0400 Subject: [PATCH 06/17] add type of queueu item to log message --- airbyte_cdk/sources/concurrent_source/concurrent_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index 855b9e20a..3e84c0367 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -146,7 +146,7 @@ def _consume_from_queue( while airbyte_message_or_record_or_exception := queue.get(): test_env = os.getenv("PYTEST_CURRENT_TEST") if test_env and "test_concurrent_declarative_source.py" in test_env: - self._logger.info(f"Processing and emitting: {airbyte_message_or_record_or_exception.__dict__}") + self._logger.info(f"Processing and emitting {type(airbyte_message_or_record_or_exception)}: {airbyte_message_or_record_or_exception.__dict__}") yield from self._handle_item( airbyte_message_or_record_or_exception, concurrent_stream_processor, From 1c3975a50baf00a610d75322e0a67dd45638c725 Mon Sep 17 00:00:00 2001 From: "maxime.c" Date: Tue, 19 Aug 2025 23:40:51 -0400 Subject: [PATCH 07/17] add logging to concurrent cursor --- airbyte_cdk/sources/streams/concurrent/cursor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 318076835..1eeca47c4 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -4,6 +4,7 @@ import functools import logging +import os from abc import ABC, abstractmethod from typing import ( Any, @@ -237,12 +238,18 @@ def _extract_cursor_value(self, record: Record) -> Any: return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) def close_partition(self, partition: Partition) -> None: + test_env = os.getenv("PYTEST_CURRENT_TEST") + if test_env and "test_concurrent_declarative_source.py" in test_env: + LOGGER.info(f"Closing partition {partition.to_slice()}") + LOGGER.info(f"\tstate before is {self._concurrent_state}") slice_count_before = len(self._concurrent_state.get("slices", [])) self._add_slice_to_state(partition) if slice_count_before < len( self._concurrent_state["slices"] ): # only emit if at least one slice has been processed self._merge_partitions() + if test_env and "test_concurrent_declarative_source.py" in test_env: + LOGGER.info(f"\tstate after merged partition is {self._concurrent_state}") self._emit_state_message() self._has_closed_at_least_one_slice = True From 5d41f4ee5f539c1a37a2089c7a4db65b56347f0e Mon Sep 17 00:00:00 2001 From: "maxime.c" Date: Wed, 20 Aug 2025 12:05:08 -0400 Subject: [PATCH 08/17] Revert "add logging to concurrent cursor" This reverts commit 1c3975a50baf00a610d75322e0a67dd45638c725. 
--- airbyte_cdk/sources/streams/concurrent/cursor.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 1eeca47c4..318076835 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -4,7 +4,6 @@ import functools import logging -import os from abc import ABC, abstractmethod from typing import ( Any, @@ -238,18 +237,12 @@ def _extract_cursor_value(self, record: Record) -> Any: return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) def close_partition(self, partition: Partition) -> None: - test_env = os.getenv("PYTEST_CURRENT_TEST") - if test_env and "test_concurrent_declarative_source.py" in test_env: - LOGGER.info(f"Closing partition {partition.to_slice()}") - LOGGER.info(f"\tstate before is {self._concurrent_state}") slice_count_before = len(self._concurrent_state.get("slices", [])) self._add_slice_to_state(partition) if slice_count_before < len( self._concurrent_state["slices"] ): # only emit if at least one slice has been processed self._merge_partitions() - if test_env and "test_concurrent_declarative_source.py" in test_env: - LOGGER.info(f"\tstate after merged partition is {self._concurrent_state}") self._emit_state_message() self._has_closed_at_least_one_slice = True From c934da123d9ad4a344c02bf7ec62de51423bbde6 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Wed, 20 Aug 2025 11:03:12 -0700 Subject: [PATCH 09/17] please save me from the madness --- .../sources/streams/concurrent/cursor.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 318076835..7f48618b0 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -4,6 +4,7 @@ import functools import logging +import os from abc import ABC, abstractmethod from typing import ( Any, @@ -17,6 +18,7 @@ Union, ) +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY @@ -237,12 +239,41 @@ def _extract_cursor_value(self, record: Record) -> Any: return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) def close_partition(self, partition: Partition) -> None: + test_env = os.getenv("PYTEST_CURRENT_TEST") + if test_env and "test_concurrent_declarative_source.py" in test_env: + self._message_repository.emit_message( + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage( + level=Level.INFO, message=f"Closing partition {partition.to_slice()}" + ), + ) + ) + self._message_repository.emit_message( + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage( + level=Level.INFO, message=f"\tstate before is {self._concurrent_state}" + ), + ) + ) + slice_count_before = len(self._concurrent_state.get("slices", [])) self._add_slice_to_state(partition) if slice_count_before < len( self._concurrent_state["slices"] ): # only emit if at least one slice has been processed self._merge_partitions() + if test_env and "test_concurrent_declarative_source.py" in test_env: + self._message_repository.emit_message( + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage( + level=Level.INFO, + message=f"\tstate after 
merged partition is {self._concurrent_state}",
+                    ),
+                )
+            )
             self._emit_state_message()
         self._has_closed_at_least_one_slice = True

From 0c2ab3a0f2139f181fc99a6ec02065764676b82a Mon Sep 17 00:00:00 2001
From: brianjlai
Date: Wed, 20 Aug 2025 12:32:54 -0700
Subject: [PATCH 10/17] add stream name to cursor log statements

---
 airbyte_cdk/sources/streams/concurrent/cursor.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py
index 7f48618b0..17a5ef43c 100644
--- a/airbyte_cdk/sources/streams/concurrent/cursor.py
+++ b/airbyte_cdk/sources/streams/concurrent/cursor.py
@@ -245,7 +245,8 @@ def close_partition(self, partition: Partition) -> None:
                 AirbyteMessage(
                     type=Type.LOG,
                     log=AirbyteLogMessage(
-                        level=Level.INFO, message=f"Closing partition {partition.to_slice()}"
+                        level=Level.INFO,
+                        message=f"[{self._stream_name}] Closing partition {partition.to_slice()}",
                     ),
                 )
             )
@@ -253,7 +254,8 @@ def close_partition(self, partition: Partition) -> None:
                 AirbyteMessage(
                     type=Type.LOG,
                     log=AirbyteLogMessage(
-                        level=Level.INFO, message=f"\tstate before is {self._concurrent_state}"
+                        level=Level.INFO,
+                        message=f"\t[{self._stream_name}] state before is {self._concurrent_state}",
                     ),
                 )
             )
@@ -270,7 +272,7 @@ def close_partition(self, partition: Partition) -> None:
                     type=Type.LOG,
                     log=AirbyteLogMessage(
                         level=Level.INFO,
-                        message=f"\tstate after merged partition is {self._concurrent_state}",
+                        message=f"\t[{self._stream_name}] state after merged partition is {self._concurrent_state}",
                     ),
                 )
             )

From a0c206c52726543efe35ce47b8cd2f7d86b22318 Mon Sep 17 00:00:00 2001
From: brianjlai
Date: Wed, 20 Aug 2025 14:54:42 -0700
Subject: [PATCH 11/17] for the concurrent cursor, add a lock around state reads/writes to prevent a race condition

---
 .../sources/streams/concurrent/cursor.py | 60 +++++++------------
 1 file changed, 22 insertions(+), 38 deletions(-)

diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py
index 17a5ef43c..b2152e857 100644
--- a/airbyte_cdk/sources/streams/concurrent/cursor.py
+++ b/airbyte_cdk/sources/streams/concurrent/cursor.py
@@ -5,6 +5,7 @@
 import functools
 import logging
 import os
+import threading
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -176,6 +177,12 @@ def __init__(
         self._should_be_synced_logger_triggered = False
         self._clamping_strategy = clamping_strategy

+        # A lock is required when closing a partition because updating the cursor's concurrent_state is
+        # not thread safe. When multiple partitions are being closed by the cursor at the same time, it is
+        # possible for one partition to update concurrent_state after a second partition has already read
+        # the previous state. This can lead to the second partition overwriting the previous one's state.
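+        #
+        # Example of the lost-update interleaving this guards against (illustrative only):
+        #   thread 1 reads slices [A]          thread 2 reads slices [A]
+        #   thread 1 writes slices [A, B]      thread 2 writes slices [A, C]  <- B is lost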
+ self._lock = threading.Lock() + @property def state(self) -> MutableMapping[str, Any]: return self._connector_state_converter.convert_to_state_message( @@ -239,44 +246,21 @@ def _extract_cursor_value(self, record: Record) -> Any: return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) def close_partition(self, partition: Partition) -> None: - test_env = os.getenv("PYTEST_CURRENT_TEST") - if test_env and "test_concurrent_declarative_source.py" in test_env: - self._message_repository.emit_message( - AirbyteMessage( - type=Type.LOG, - log=AirbyteLogMessage( - level=Level.INFO, - message=f"[{self._stream_name}] Closing partition {partition.to_slice()}", - ), - ) - ) - self._message_repository.emit_message( - AirbyteMessage( - type=Type.LOG, - log=AirbyteLogMessage( - level=Level.INFO, - message=f"\t[{self._stream_name}] state before is {self._concurrent_state}", - ), - ) - ) - - slice_count_before = len(self._concurrent_state.get("slices", [])) - self._add_slice_to_state(partition) - if slice_count_before < len( - self._concurrent_state["slices"] - ): # only emit if at least one slice has been processed - self._merge_partitions() - if test_env and "test_concurrent_declarative_source.py" in test_env: - self._message_repository.emit_message( - AirbyteMessage( - type=Type.LOG, - log=AirbyteLogMessage( - level=Level.INFO, - message=f"\t[{self._stream_name}] state after merged partition is {self._concurrent_state}", - ), - ) - ) - self._emit_state_message() + with self._lock: + slice_count_before = len(self._concurrent_state.get("slices", [])) + self._add_slice_to_state(partition) + if slice_count_before < len( + self._concurrent_state["slices"] + ): # only emit if at least one slice has been processed + self._merge_partitions() + self._emit_state_message() + # slice_count_before = len(self._concurrent_state.get("slices", [])) + # self._add_slice_to_state(partition) + # if slice_count_before < len( + # self._concurrent_state["slices"] + # ): # only emit if at least one slice has been processed + # self._merge_partitions() + # self._emit_state_message() self._has_closed_at_least_one_slice = True def _add_slice_to_state(self, partition: Partition) -> None: From 46d10f67f13101a5423ef46b2113d16e77890950 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Wed, 20 Aug 2025 19:49:53 -0700 Subject: [PATCH 12/17] remove extra logging code and some code rabbit PR suggestions --- .../concurrent_source/concurrent_source.py | 5 +- .../sources/message/concurrent_repository.py | 23 -------- .../sources/streams/concurrent/cursor.py | 7 --- .../sources/streams/http/http_client.py | 6 --- .../test_connector_builder_handler.py | 54 +++++++++---------- 5 files changed, 28 insertions(+), 67 deletions(-) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte_cdk/sources/concurrent_source/concurrent_source.py index 3e84c0367..de2d93523 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -1,9 +1,9 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
# + import concurrent import logging -import os from queue import Queue from typing import Iterable, Iterator, List, Optional @@ -144,9 +144,6 @@ def _consume_from_queue( concurrent_stream_processor: ConcurrentReadProcessor, ) -> Iterable[AirbyteMessage]: while airbyte_message_or_record_or_exception := queue.get(): - test_env = os.getenv("PYTEST_CURRENT_TEST") - if test_env and "test_concurrent_declarative_source.py" in test_env: - self._logger.info(f"Processing and emitting {type(airbyte_message_or_record_or_exception)}: {airbyte_message_or_record_or_exception.__dict__}") yield from self._handle_item( airbyte_message_or_record_or_exception, concurrent_stream_processor, diff --git a/airbyte_cdk/sources/message/concurrent_repository.py b/airbyte_cdk/sources/message/concurrent_repository.py index 1ed07b9ff..e3bc7116a 100644 --- a/airbyte_cdk/sources/message/concurrent_repository.py +++ b/airbyte_cdk/sources/message/concurrent_repository.py @@ -29,23 +29,14 @@ def __init__(self, queue: Queue[QueueItem], message_repository: MessageRepositor self._queue = queue self._decorated_message_repository = message_repository - test_env = os.getenv("PYTEST_CURRENT_TEST") - self._log_messages_for_testing = ( - test_env and "test_concurrent_declarative_source.py" in test_env - ) - def emit_message(self, message: AirbyteMessage) -> None: self._decorated_message_repository.emit_message(message) for message in self._decorated_message_repository.consume_queue(): - if self._log_messages_for_testing: - self._log_message(message, "emit_message()") self._queue.put(message) def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: self._decorated_message_repository.log_message(level, message_provider) for message in self._decorated_message_repository.consume_queue(): - if self._log_messages_for_testing: - self._log_message(message, "log_message()") self._queue.put(message) def consume_queue(self) -> Iterable[AirbyteMessage]: @@ -54,17 +45,3 @@ def consume_queue(self) -> Iterable[AirbyteMessage]: loading messages onto the queue processed on the main thread. 
""" yield from [] - - @staticmethod - def _log_message(message: AirbyteMessage, calling_method: str) -> None: - if message.type == MessageType.STATE: - if message.state and message.state.stream: - stream_name = message.state.stream.stream_descriptor.name - state = message.state.stream.stream_state.__dict__ - logger.info( - f"From {calling_method} -- emitting message of type {message.type} for stream {stream_name} with contents: {state}" - ) - else: - logger.info( - f"From {calling_method} -- emitting message of type {message.type} with contents: {message.__dict__}" - ) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index b2152e857..7f6e62596 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -254,13 +254,6 @@ def close_partition(self, partition: Partition) -> None: ): # only emit if at least one slice has been processed self._merge_partitions() self._emit_state_message() - # slice_count_before = len(self._concurrent_state.get("slices", [])) - # self._add_slice_to_state(partition) - # if slice_count_before < len( - # self._concurrent_state["slices"] - # ): # only emit if at least one slice has been processed - # self._merge_partitions() - # self._emit_state_message() self._has_closed_at_least_one_slice = True def _add_slice_to_state(self, partition: Partition) -> None: diff --git a/airbyte_cdk/sources/streams/http/http_client.py b/airbyte_cdk/sources/streams/http/http_client.py index e10c95890..ff3c8e733 100644 --- a/airbyte_cdk/sources/streams/http/http_client.py +++ b/airbyte_cdk/sources/streams/http/http_client.py @@ -127,12 +127,6 @@ def cache_filename(self) -> str: Override if needed. Return the name of cache file Note that if the environment variable REQUEST_CACHE_PATH is not set, the cache will be in-memory only. 
""" - # This is a hack so that we ensure that the same cache is not used across different test files - # because we observed some flakiness in tests when running on CI - # https://github.com/airbytehq/airbyte-python-cdk/pull/688 - # https://github.com/airbytehq/airbyte-python-cdk/pull/712 - if os.getenv("PYTEST_CURRENT_TEST"): - return f"{self._name}-{os.getenv('PYTEST_CURRENT_TEST')}.sqlite" return f"{self._name}.sqlite" def _request_session(self) -> requests.Session: diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index 063eb4482..643878eec 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -17,7 +17,6 @@ from airbyte_cdk import connector_builder from airbyte_cdk.connector_builder.connector_builder_handler import ( - TestLimits, create_source, get_limits, resolve_manifest, @@ -893,12 +892,13 @@ def test_handle_429_response(): {"result": [{"error": "too many requests"}], "_metadata": {"next": "next"}} ) + config = copy.deepcopy(TEST_READ_CONFIG) + # Add backoff strategy to avoid default endless backoff loop - TEST_READ_CONFIG["__injected_declarative_manifest"]["definitions"]["retriever"]["requester"][ + config["__injected_declarative_manifest"]["definitions"]["retriever"]["requester"][ "error_handler" ] = {"backoff_strategies": [{"type": "ConstantBackoffStrategy", "backoff_time_in_seconds": 5}]} - config = TEST_READ_CONFIG limits = TestLimits() catalog = ConfiguredAirbyteCatalogSerializer.load(CONFIGURED_CATALOG) source = create_source(config=config, limits=limits, catalog=catalog, state=None) @@ -1256,7 +1256,7 @@ def test_handle_read_external_requests(deployment_mode, url_base, expected_error ] ) - test_manifest = MANIFEST + test_manifest = copy.deepcopy(MANIFEST) test_manifest["streams"][0]["$parameters"]["url_base"] = url_base config = {"__injected_declarative_manifest": test_manifest} @@ -1350,7 +1350,7 @@ def test_handle_read_external_oauth_request(deployment_mode, token_url, expected "refresh_token": "john", } - test_manifest = MANIFEST + test_manifest = copy.deepcopy(MANIFEST) test_manifest["definitions"]["retriever"]["requester"]["authenticator"] = ( oauth_authenticator_config ) @@ -1486,11 +1486,11 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "type": "RequestOption", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "page_token_option": { @@ -1498,11 +1498,11 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "type": "RequestPath", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "pagination_strategy": { @@ -1511,20 +1511,20 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "page_size": 2, "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", 
"primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "partition_router": { @@ -1533,11 +1533,11 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "cursor_field": "item_id", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "requester": { @@ -1547,22 +1547,22 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "api_token": "{{ config.apikey }}", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "request_parameters": {"a_param": "10"}, "type": "HttpRequester", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "record_selector": { @@ -1571,40 +1571,40 @@ def test_full_resolve_manifest(valid_resolve_manifest_config_file): "type": "DpathExtractor", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "type": "RecordSelector", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "type": "SimpleRetriever", "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, }, "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", "$parameters": { "name": "stream_with_custom_requester", "primary_key": "id", - "url_base": "https://10.0.27.27/api/v1/", + "url_base": "https://api.sendgrid.com", }, "dynamic_stream_name": None, }, From df5226196fc750c121cf168ca6b52b88d958b5b0 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Wed, 20 Aug 2025 21:05:45 -0700 Subject: [PATCH 13/17] pr feedback --- .../stream_slicers/declarative_partition_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py index 4a511fe70..809936ae0 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py @@ -82,14 +82,14 @@ def __init__( self._hash = SliceHasher.hash(self._stream_name, self._stream_slice) def read(self) -> Iterable[Record]: - if self._max_records_limit: + if self._max_records_limit is not None: global total_record_counter if total_record_counter >= self._max_records_limit: return for stream_data in self._retriever.read_records( self._schema_loader.get_json_schema(), self._stream_slice ): - if self._max_records_limit: + if self._max_records_limit is not None: if total_record_counter >= self._max_records_limit: break @@ -107,7 +107,7 @@ def read(self) -> Iterable[Record]: else: self._message_repository.emit_message(stream_data) - if self._max_records_limit: + if self._max_records_limit is not None: total_record_counter += 1 def to_slice(self) -> Optional[Mapping[str, Any]]: From 38b4666982048368a78211f28efabffb07f277f1 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Thu, 21 Aug 2025 16:34:05 -0700 Subject: [PATCH 14/17] add comment about thread safety for observe() and follow up items --- airbyte_cdk/sources/streams/concurrent/cursor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 7f6e62596..0c423f72a 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -231,6 +231,13 @@ def _get_concurrent_state( ) def observe(self, record: Record) -> None: + # Because observe writes to the most_recent_cursor_value_per_partition mapping, + # it is not thread-safe. However, this shouldn't lead to concurrency issues + # because observe() is only invoked on the main thread and the map is broken + # down by partition which should not have conflicting read/write. 
+ # + # If we were to add thread safety, we should implement a lock per-partition + # which is instantiated during stream_slices() most_recent_cursor_value = self._most_recent_cursor_value_per_partition.get( record.associated_slice ) From 3316766ef2c954ce5c2a66868a6acd219dd736b6 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Fri, 22 Aug 2025 10:19:16 -0700 Subject: [PATCH 15/17] pr feedback correct comment --- airbyte_cdk/sources/streams/concurrent/cursor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 0c423f72a..686b074c5 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -4,7 +4,6 @@ import functools import logging -import os import threading from abc import ABC, abstractmethod from typing import ( @@ -19,7 +18,6 @@ Union, ) -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY @@ -233,8 +231,12 @@ def _get_concurrent_state( def observe(self, record: Record) -> None: # Because observe writes to the most_recent_cursor_value_per_partition mapping, # it is not thread-safe. However, this shouldn't lead to concurrency issues - # because observe() is only invoked on the main thread and the map is broken - # down by partition which should not have conflicting read/write. + # because observe() is only invoked in two ways both of which aren't conflicting: + # - ConcurrentReadProcessor.on_record(): Since records are observed on the main + # thread and so there aren't concurrent operations. Partitions are also split by key. + # - PartitionReader.process_partition(): Because the map is broken down according to + # partition, concurrent threads processing only read/write from different keys which + # avoids any conflicts. 
# # If we were to add thread safety, we should implement a lock per-partition # which is instantiated during stream_slices() From 78e04692e0cf5428a486ec61644bf4ee131c462e Mon Sep 17 00:00:00 2001 From: brianjlai Date: Fri, 22 Aug 2025 10:50:34 -0700 Subject: [PATCH 16/17] remove unneeded observe() and update comment --- .../concurrent_source/concurrent_read_processor.py | 1 - airbyte_cdk/sources/streams/concurrent/cursor.py | 11 ++++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index bbc359a58..905999a4d 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -154,7 +154,6 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: stream.as_airbyte_stream(), AirbyteStreamStatus.RUNNING ) self._record_counter[stream.name] += 1 - stream.cursor.observe(record) yield message yield from self._message_repository.consume_queue() diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py index 686b074c5..ca63a6901 100644 --- a/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -230,13 +230,10 @@ def _get_concurrent_state( def observe(self, record: Record) -> None: # Because observe writes to the most_recent_cursor_value_per_partition mapping, - # it is not thread-safe. However, this shouldn't lead to concurrency issues - # because observe() is only invoked in two ways both of which aren't conflicting: - # - ConcurrentReadProcessor.on_record(): Since records are observed on the main - # thread and so there aren't concurrent operations. Partitions are also split by key. - # - PartitionReader.process_partition(): Because the map is broken down according to - # partition, concurrent threads processing only read/write from different keys which - # avoids any conflicts. + # it is not thread-safe. However, this shouldn't lead to concurrency issues because + # observe() is only invoked by PartitionReader.process_partition(). Since the map is + # broken down according to partition, concurrent threads processing only read/write + # from different keys which avoids any conflicts. 
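+        # (process_partition() observes each record as it reads the partition, and the
+        # mapping is keyed by the record's associated_slice, so threads working on
+        # different partitions always touch different keys.)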
# # If we were to add thread safety, we should implement a lock per-partition # which is instantiated during stream_slices() From 772c77aa7b4f15c4ef18d3834f46aed56d6467e5 Mon Sep 17 00:00:00 2001 From: brianjlai Date: Fri, 22 Aug 2025 12:09:58 -0700 Subject: [PATCH 17/17] fix test to use PartitionReader where cursor close is now invoked --- .../sources/streams/test_stream_read.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/unit_tests/sources/streams/test_stream_read.py b/unit_tests/sources/streams/test_stream_read.py index bf13ac351..cf550f8cf 100644 --- a/unit_tests/sources/streams/test_stream_read.py +++ b/unit_tests/sources/streams/test_stream_read.py @@ -4,6 +4,7 @@ import logging from copy import deepcopy +from queue import Queue from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Union from unittest.mock import Mock @@ -589,7 +590,10 @@ def test_concurrent_incremental_read_two_slices(): *records_partition_2, ] - expected_state = _create_state_message( + expected_state_1 = _create_state_message( + "__mock_stream", {"1": {"created_at": slice_timestamp_1}} + ) + expected_state_2 = _create_state_message( "__mock_stream", {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}}, ) @@ -617,26 +621,27 @@ def test_concurrent_incremental_read_two_slices(): for record in expected_records: assert record in actual_records - # We need run on_record to update cursor with record cursor value - for record in actual_records: - list( - handler.on_record( - Record( - data=record, - stream_name="__mock_stream", - ) - ) - ) + # We need to process partitions generated by a PartitionReader in order to trigger + # the ConcurrentCursor.close_partition() flow and validate state is updated with + # the observed record values + partition_reader = PartitionReader(queue=Mock(spec=Queue)) + assert isinstance(stream, StreamFacade) + abstract_stream = stream._abstract_stream + for partition in abstract_stream.generate_partitions(): + partition_reader.process_partition(partition=partition, cursor=cursor) assert len(actual_records) == len(expected_records) - # We don't have a real source that reads from the message_repository for state, so we read from the queue directly to verify - # the cursor observed records correctly and updated partition states - mock_partition = Mock() - cursor.close_partition(mock_partition) actual_state = [state for state in message_repository.consume_queue()] - assert len(actual_state) == 1 - assert actual_state[0] == expected_state + assert len(actual_state) == 2 + assert ( + actual_state[0].state.stream.stream_state.__dict__ + == expected_state_1.state.stream.stream_state.__dict__ + ) + assert ( + actual_state[1].state.stream.stream_state.__dict__ + == expected_state_2.state.stream.stream_state.__dict__ + ) def setup_stream_dependencies(configured_json_schema):