diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 905999a4d..5d5a2c4cc 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -41,6 +41,7 @@ def __init__( ): """ This class is responsible for handling items from a concurrent stream read process. + :param stream_instances_to_read_from: List of streams to read from :param partition_enqueuer: PartitionEnqueuer instance :param thread_pool_manager: ThreadPoolManager instance @@ -50,21 +51,25 @@ def __init__( :param partition_reader: PartitionReader instance """ self._stream_name_to_instance = {s.name: s for s in stream_instances_to_read_from} - self._record_counter = {} + self._record_counter: Dict[str, int] = {} self._streams_to_running_partitions: Dict[str, Set[Partition]] = {} for stream in stream_instances_to_read_from: self._streams_to_running_partitions[stream.name] = set() self._record_counter[stream.name] = 0 self._thread_pool_manager = thread_pool_manager self._partition_enqueuer = partition_enqueuer - self._stream_instances_to_start_partition_generation = stream_instances_to_read_from + self._stream_instances_to_start_partition_generation = list(stream_instances_to_read_from) self._streams_currently_generating_partitions: List[str] = [] self._logger = logger self._slice_logger = slice_logger self._message_repository = message_repository self._partition_reader = partition_reader self._streams_done: Set[str] = set() - self._exceptions_per_stream_name: dict[str, List[Exception]] = {} + self._exceptions_per_stream_name: Dict[str, List[Exception]] = {} + + # Track active concurrency groups and deferred streams + self._active_concurrency_groups: Set[str] = set() + self._deferred_streams: List[AbstractStream] = [] def on_partition_generation_completed( self, sentinel: PartitionGenerationCompletedSentinel @@ -182,13 +187,22 @@ def _flag_exception(self, stream_name: str, exception: Exception) -> None: def start_next_partition_generator(self) -> Optional[AirbyteMessage]: """ Start the next partition generator. - 1. Pop the next stream to read from + + 1. Find the next stream that can be started (respecting concurrency groups) 2. Submit the partition generator to the thread pool manager 3. Add the stream to the list of streams currently generating partitions - 4. Return a stream status message + 4. Mark the concurrency group as active if applicable + 5. Return a stream status message """ - if self._stream_instances_to_start_partition_generation: - stream = self._stream_instances_to_start_partition_generation.pop(0) + stream = self._get_next_eligible_stream() + if stream: + concurrency_group = stream.concurrency_group + if concurrency_group: + self._active_concurrency_groups.add(concurrency_group) + self._logger.debug( + f"Stream {stream.name} activated concurrency group '{concurrency_group}'" + ) + self._thread_pool_manager.submit(self._partition_enqueuer.generate_partitions, stream) self._streams_currently_generating_partitions.append(stream.name) self._logger.info(f"Marking stream {stream.name} as STARTED") @@ -200,14 +214,52 @@ def start_next_partition_generator(self) -> Optional[AirbyteMessage]: else: return None + def _get_next_eligible_stream(self) -> Optional[AbstractStream]: + """ + Get the next stream that can be started, respecting concurrency groups. + + Streams with a concurrency group that is already active will be deferred + until the group becomes inactive. + + :return: The next eligible stream, or None if no streams are available + """ + eligible_stream: Optional[AbstractStream] = None + streams_to_defer: List[AbstractStream] = [] + + while self._stream_instances_to_start_partition_generation: + stream = self._stream_instances_to_start_partition_generation.pop(0) + concurrency_group = stream.concurrency_group + + if concurrency_group and concurrency_group in self._active_concurrency_groups: + # This stream's concurrency group is active, defer it + streams_to_defer.append(stream) + self._logger.debug( + f"Deferring stream {stream.name} because concurrency group " + f"'{concurrency_group}' is active" + ) + else: + # This stream can be started + eligible_stream = stream + break + + # Add deferred streams back to the list (at the end) + self._deferred_streams.extend(streams_to_defer) + + return eligible_stream + def is_done(self) -> bool: """ - This method is called to check if the sync is done. + Check if the sync is done. + The sync is done when: 1. There are no more streams generating partitions - 2. There are no more streams to read from + 2. There are no more streams to read from (including deferred streams) 3. All partitions for all streams are closed """ + # Check if there are still deferred streams waiting + if self._deferred_streams: + return False + is_done = all( [ self._is_stream_done(stream_name) @@ -240,9 +292,72 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]: yield from self._message_repository.consume_queue() self._logger.info(f"Finished syncing {stream.name}") self._streams_done.add(stream_name) + + # Deactivate concurrency group if this stream had one and no other streams + # in the same group are still running + concurrency_group = stream.concurrency_group + if concurrency_group and not self._is_concurrency_group_active(concurrency_group): + self._active_concurrency_groups.discard(concurrency_group) + self._logger.debug( + f"Deactivated concurrency group '{concurrency_group}' after stream " + f"{stream_name} completed" + ) + # Re-queue deferred streams that were waiting for this group + self._requeue_deferred_streams_for_group(concurrency_group) + stream_status = ( AirbyteStreamStatus.INCOMPLETE if self._exceptions_per_stream_name.get(stream_name, []) else AirbyteStreamStatus.COMPLETE ) yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), stream_status) + + def _is_concurrency_group_active(self, concurrency_group: str) -> bool: + """ + Check if a concurrency group still has active streams. + + A group is active if any stream in the group is either generating partitions + or has running partitions. + + :param concurrency_group: The concurrency group to check + :return: True if the group has active streams, False otherwise + """ + for stream_name in self._streams_currently_generating_partitions: + stream = self._stream_name_to_instance[stream_name] + if stream.concurrency_group == concurrency_group: + return True + + for stream_name, partitions in self._streams_to_running_partitions.items(): + if partitions: # Has running partitions + stream = self._stream_name_to_instance[stream_name] + if stream.concurrency_group == concurrency_group: + return True + + return False + + def _requeue_deferred_streams_for_group(self, concurrency_group: str) -> None: + """ + Move deferred streams that were waiting for a concurrency group back to the main queue. + + :param concurrency_group: The concurrency group that just became inactive + """ + streams_to_requeue: List[AbstractStream] = [] + remaining_deferred: List[AbstractStream] = [] + + for stream in self._deferred_streams: + if stream.concurrency_group == concurrency_group: + streams_to_requeue.append(stream) + else: + remaining_deferred.append(stream) + + if streams_to_requeue: + self._logger.debug( + f"Re-queuing {len(streams_to_requeue)} deferred stream(s) for concurrency " + f"group '{concurrency_group}': {[s.name for s in streams_to_requeue]}" + ) + # Add to the front of the queue so they get processed next + self._stream_instances_to_start_partition_generation = ( + streams_to_requeue + self._stream_instances_to_start_partition_generation + ) + + self._deferred_streams = remaining_deferred diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index e04a82c0d..3c7ee1b34 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -1553,6 +1553,16 @@ definitions: default: "" example: - "Users" + concurrency_group: + title: Concurrency Group + description: > + Streams with the same concurrency group will be processed serially with respect to each other. + This is useful for APIs that limit concurrent requests to certain endpoints, such as scroll-based + pagination APIs that only allow one active scroll at a time. Streams without a concurrency group + (or with different groups) will be processed concurrently as normal. + type: string + example: + - "scroll" retriever: title: Retriever description: Component used to coordinate how records are extracted across stream slices and request pages. diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index b78a07021..d818c390d 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -2497,6 +2497,12 @@ class Config: type: Literal["DeclarativeStream"] name: Optional[str] = Field("", description="The stream name.", example=["Users"], title="Name") + concurrency_group: Optional[str] = Field( + None, + description="Streams with the same concurrency group will be processed serially with respect to each other. This is useful for APIs that limit concurrent requests to certain endpoints, such as scroll-based pagination APIs that only allow one active scroll at a time. Streams without a concurrency group (or with different groups) will be processed concurrently as normal.", + example=["scroll"], + title="Concurrency Group", + ) retriever: Union[SimpleRetriever, AsyncRetriever, CustomRetriever] = Field( ..., description="Component used to coordinate how records are extracted across stream slices and request pages.", diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 3a772b691..07079c67b 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -2118,6 +2118,7 @@ def create_default_stream( logger=logging.getLogger(f"airbyte.{stream_name}"), cursor=concurrent_cursor, supports_file_transfer=hasattr(model, "file_uploader") and bool(model.file_uploader), + concurrency_group=model.concurrency_group, ) def _migrate_state(self, model: DeclarativeStreamModel, config: Config) -> None: diff --git a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py index 667d088ab..ddcf2e3f2 100644 --- a/airbyte_cdk/sources/streams/concurrent/abstract_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/abstract_stream.py @@ -90,3 +90,18 @@ def check_availability(self) -> StreamAvailability: """ :return: If the stream is available and if not, why """ + + @property + def concurrency_group(self) -> Optional[str]: + """ + Returns the concurrency group for this stream. + + Streams with the same non-None concurrency group will be processed serially + with respect to each other. This is useful for APIs that limit concurrent + requests to certain endpoints (e.g., scroll-based pagination APIs that only + allow one active scroll at a time). + + :return: The concurrency group name, or None if the stream can be processed + concurrently with all other streams. + """ + return None diff --git a/airbyte_cdk/sources/streams/concurrent/default_stream.py b/airbyte_cdk/sources/streams/concurrent/default_stream.py index f5d4ccf2e..9f0656a70 100644 --- a/airbyte_cdk/sources/streams/concurrent/default_stream.py +++ b/airbyte_cdk/sources/streams/concurrent/default_stream.py @@ -26,6 +26,7 @@ def __init__( cursor: Cursor, namespace: Optional[str] = None, supports_file_transfer: bool = False, + concurrency_group: Optional[str] = None, ) -> None: self._stream_partition_generator = partition_generator self._name = name @@ -36,6 +37,7 @@ def __init__( self._cursor = cursor self._namespace = namespace self._supports_file_transfer = supports_file_transfer + self._concurrency_group = concurrency_group def generate_partitions(self) -> Iterable[Partition]: yield from self._stream_partition_generator.generate() @@ -94,6 +96,10 @@ def log_stream_sync_configuration(self) -> None: def cursor(self) -> Cursor: return self._cursor + @property + def concurrency_group(self) -> Optional[str]: + return self._concurrency_group + def check_availability(self) -> StreamAvailability: """ Check stream availability by attempting to read the first record of the stream. diff --git a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index a681f75eb..040649aad 100644 --- a/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -792,3 +792,210 @@ def test_start_next_partition_generator(self): self._thread_pool_manager.submit.assert_called_with( self._partition_enqueuer.generate_partitions, self._stream ) + + def test_concurrency_group_defers_stream_when_group_is_active(self): + """Test that streams with the same concurrency group are deferred when the group is active.""" + # Create two streams with the same concurrency group + stream1 = Mock(spec=AbstractStream) + stream1.name = "stream1" + stream1.concurrency_group = "scroll" + stream1.as_airbyte_stream.return_value = AirbyteStream( + name="stream1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + stream2 = Mock(spec=AbstractStream) + stream2.name = "stream2" + stream2.concurrency_group = "scroll" + stream2.as_airbyte_stream.return_value = AirbyteStream( + name="stream2", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + handler = ConcurrentReadProcessor( + [stream1, stream2], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + + # Start the first stream - should succeed + handler.start_next_partition_generator() + assert "stream1" in handler._streams_currently_generating_partitions + assert "scroll" in handler._active_concurrency_groups + + # Try to start the second stream - should be deferred + result = handler.start_next_partition_generator() + assert result is None # No stream started + assert "stream2" not in handler._streams_currently_generating_partitions + assert stream2 in handler._deferred_streams + + def test_concurrency_group_allows_streams_without_group_to_run_concurrently(self): + """Test that streams without a concurrency group can run concurrently.""" + stream1 = Mock(spec=AbstractStream) + stream1.name = "stream1" + stream1.concurrency_group = None + stream1.as_airbyte_stream.return_value = AirbyteStream( + name="stream1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + stream2 = Mock(spec=AbstractStream) + stream2.name = "stream2" + stream2.concurrency_group = None + stream2.as_airbyte_stream.return_value = AirbyteStream( + name="stream2", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + handler = ConcurrentReadProcessor( + [stream1, stream2], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + + # Start both streams - both should succeed + handler.start_next_partition_generator() + handler.start_next_partition_generator() + + assert "stream1" in handler._streams_currently_generating_partitions + assert "stream2" in handler._streams_currently_generating_partitions + assert len(handler._deferred_streams) == 0 + + def test_concurrency_group_allows_different_groups_to_run_concurrently(self): + """Test that streams with different concurrency groups can run concurrently.""" + stream1 = Mock(spec=AbstractStream) + stream1.name = "stream1" + stream1.concurrency_group = "group_a" + stream1.as_airbyte_stream.return_value = AirbyteStream( + name="stream1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + stream2 = Mock(spec=AbstractStream) + stream2.name = "stream2" + stream2.concurrency_group = "group_b" + stream2.as_airbyte_stream.return_value = AirbyteStream( + name="stream2", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + handler = ConcurrentReadProcessor( + [stream1, stream2], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + + # Start both streams - both should succeed since they have different groups + handler.start_next_partition_generator() + handler.start_next_partition_generator() + + assert "stream1" in handler._streams_currently_generating_partitions + assert "stream2" in handler._streams_currently_generating_partitions + assert "group_a" in handler._active_concurrency_groups + assert "group_b" in handler._active_concurrency_groups + assert len(handler._deferred_streams) == 0 + + @freezegun.freeze_time("2020-01-01T00:00:00") + def test_concurrency_group_requeues_deferred_stream_when_group_becomes_inactive(self): + """Test that deferred streams are re-queued and started when their concurrency group becomes inactive.""" + stream1 = Mock(spec=AbstractStream) + stream1.name = "stream1" + stream1.concurrency_group = "scroll" + stream1.as_airbyte_stream.return_value = AirbyteStream( + name="stream1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + stream2 = Mock(spec=AbstractStream) + stream2.name = "stream2" + stream2.concurrency_group = "scroll" + stream2.as_airbyte_stream.return_value = AirbyteStream( + name="stream2", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + self._message_repository.consume_queue.return_value = [] + + handler = ConcurrentReadProcessor( + [stream1, stream2], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + + # Start stream1 + handler.start_next_partition_generator() + assert "stream1" in handler._streams_currently_generating_partitions + + # Try to start stream2 - should be deferred + handler.start_next_partition_generator() + assert stream2 in handler._deferred_streams + + # Complete stream1's partition generation (stream1 has no partitions, so it's done) + sentinel = PartitionGenerationCompletedSentinel(stream1) + list(handler.on_partition_generation_completed(sentinel)) + + # stream2 should now be started (re-queued and then started by on_partition_generation_completed) + assert stream2 not in handler._deferred_streams + assert "stream2" in handler._streams_currently_generating_partitions + + def test_is_done_returns_false_when_deferred_streams_exist(self): + """Test that is_done returns False when there are deferred streams.""" + stream1 = Mock(spec=AbstractStream) + stream1.name = "stream1" + stream1.concurrency_group = "scroll" + stream1.as_airbyte_stream.return_value = AirbyteStream( + name="stream1", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + stream2 = Mock(spec=AbstractStream) + stream2.name = "stream2" + stream2.concurrency_group = "scroll" + stream2.as_airbyte_stream.return_value = AirbyteStream( + name="stream2", + json_schema={}, + supported_sync_modes=[SyncMode.full_refresh], + ) + + handler = ConcurrentReadProcessor( + [stream1, stream2], + self._partition_enqueuer, + self._thread_pool_manager, + self._logger, + self._slice_logger, + self._message_repository, + self._partition_reader, + ) + + # Start stream1 and defer stream2 + handler.start_next_partition_generator() + handler.start_next_partition_generator() + + # Even if stream1 is done, is_done should return False because stream2 is deferred + handler._streams_done.add("stream1") + assert not handler.is_done()