diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 1cb53104b..2a1b23242 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -2894,7 +2894,7 @@ definitions: title: Lazy Read Pointer description: If set, this will enable lazy reading, using the initial read of parent records to extract child records. type: array - default: [ ] + default: [] items: - type: string interpolation_context: @@ -3199,7 +3199,7 @@ definitions: properties: type: type: string - enum: [ StateDelegatingStream ] + enum: [StateDelegatingStream] name: title: Name description: The stream name. @@ -3254,12 +3254,14 @@ definitions: - "$ref": "#/definitions/CustomPartitionRouter" - "$ref": "#/definitions/ListPartitionRouter" - "$ref": "#/definitions/SubstreamPartitionRouter" + - "$ref": "#/definitions/GroupingPartitionRouter" - type: array items: anyOf: - "$ref": "#/definitions/CustomPartitionRouter" - "$ref": "#/definitions/ListPartitionRouter" - "$ref": "#/definitions/SubstreamPartitionRouter" + - "$ref": "#/definitions/GroupingPartitionRouter" decoder: title: Decoder description: Component decoding the response so records can be extracted. @@ -3414,12 +3416,14 @@ definitions: - "$ref": "#/definitions/CustomPartitionRouter" - "$ref": "#/definitions/ListPartitionRouter" - "$ref": "#/definitions/SubstreamPartitionRouter" + - "$ref": "#/definitions/GroupingPartitionRouter" - type: array items: anyOf: - "$ref": "#/definitions/CustomPartitionRouter" - "$ref": "#/definitions/ListPartitionRouter" - "$ref": "#/definitions/SubstreamPartitionRouter" + - "$ref": "#/definitions/GroupingPartitionRouter" decoder: title: Decoder description: Component decoding the response so records can be extracted. 
@@ -3536,6 +3540,44 @@ definitions: $parameters: type: object additionalProperties: true + GroupingPartitionRouter: + title: Grouping Partition Router + description: > + A decorator on top of a partition router that groups partitions into batches of a specified size. + This is useful for APIs that support filtering by multiple partition keys in a single request. + Note that per-partition incremental syncs may not work as expected because the grouping + of partitions might change between syncs, potentially leading to inconsistent state tracking. + type: object + required: + - type + - group_size + - underlying_partition_router + properties: + type: + type: string + enum: [GroupingPartitionRouter] + group_size: + title: Group Size + description: The number of partitions to include in each group. This determines how many partition values are batched together in a single slice. + type: integer + examples: + - 10 + - 50 + underlying_partition_router: + title: Underlying Partition Router + description: The partition router whose output will be grouped. This can be any valid partition router component. + anyOf: + - "$ref": "#/definitions/CustomPartitionRouter" + - "$ref": "#/definitions/ListPartitionRouter" + - "$ref": "#/definitions/SubstreamPartitionRouter" + deduplicate: + title: Deduplicate Partitions + description: If true, ensures that partitions are unique across the stream by removing duplicates based on the partition key. + type: boolean + default: true + $parameters: + type: object + additionalProperties: true WaitUntilTimeFromHeader: title: Wait Until Time Defined In Response Header description: Extract time at which we can retry the request from response header and wait for the difference between now and that time. 
diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py index 715589026..a0c541dc4 100644 --- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py @@ -79,6 +79,7 @@ def __init__( connector_state_manager: ConnectorStateManager, connector_state_converter: AbstractStreamStateConverter, cursor_field: CursorField, + use_global_cursor: bool = False, ) -> None: self._global_cursor: Optional[StreamState] = {} self._stream_name = stream_name @@ -106,7 +107,7 @@ def __init__( self._lookback_window: int = 0 self._parent_state: Optional[StreamState] = None self._number_of_partitions: int = 0 - self._use_global_cursor: bool = False + self._use_global_cursor: bool = use_global_cursor self._partition_serializer = PerPartitionKeySerializer() # Track the last time a state message was emitted self._last_emission_time: float = 0.0 diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index e4fb459ff..06bb31230 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -2301,7 +2301,15 @@ class SimpleRetriever(BaseModel): CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], + GroupingPartitionRouter, + List[ + Union[ + CustomPartitionRouter, + ListPartitionRouter, + SubstreamPartitionRouter, + GroupingPartitionRouter, + ] + ], ] ] = Field( [], @@ -2379,7 +2387,15 @@ class AsyncRetriever(BaseModel): CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter, - List[Union[CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter]], + 
GroupingPartitionRouter, + List[ + Union[ + CustomPartitionRouter, + ListPartitionRouter, + SubstreamPartitionRouter, + GroupingPartitionRouter, + ] + ], ] ] = Field( [], @@ -2431,6 +2447,29 @@ class SubstreamPartitionRouter(BaseModel): parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") +class GroupingPartitionRouter(BaseModel): + type: Literal["GroupingPartitionRouter"] + group_size: int = Field( + ..., + description="The number of partitions to include in each group. This determines how many partition values are batched together in a single slice.", + examples=[10, 50], + title="Group Size", + ) + underlying_partition_router: Union[ + CustomPartitionRouter, ListPartitionRouter, SubstreamPartitionRouter + ] = Field( + ..., + description="The partition router whose output will be grouped. This can be any valid partition router component.", + title="Underlying Partition Router", + ) + deduplicate: Optional[bool] = Field( + True, + description="If true, ensures that partitions are unique within each group by removing duplicates based on the partition key.", + title="Deduplicate Partitions", + ) + parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters") + + class HttpComponentsResolver(BaseModel): type: Literal["HttpComponentsResolver"] retriever: Union[AsyncRetriever, CustomRetriever, SimpleRetriever] = Field( diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 86e880b20..7dd8f50b9 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -227,6 +227,9 @@ from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( FlattenFields as FlattenFieldsModel, ) +from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( + GroupingPartitionRouter as GroupingPartitionRouterModel, +) from 
airbyte_cdk.sources.declarative.models.declarative_component_schema import ( GzipDecoder as GzipDecoderModel, ) @@ -385,6 +388,7 @@ ) from airbyte_cdk.sources.declarative.partition_routers import ( CartesianProductStreamSlicer, + GroupingPartitionRouter, ListPartitionRouter, PartitionRouter, SinglePartitionRouter, @@ -638,6 +642,7 @@ def _init_mappings(self) -> None: UnlimitedCallRatePolicyModel: self.create_unlimited_call_rate_policy, RateModel: self.create_rate, HttpRequestRegexMatcherModel: self.create_http_request_matcher, + GroupingPartitionRouterModel: self.create_grouping_partition_router, } # Needed for the case where we need to perform a second parse on the fields of a custom component @@ -1355,6 +1360,9 @@ def create_concurrent_cursor_from_perpartition_cursor( ) stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state) + # Per-partition state doesn't make sense for GroupingPartitionRouter, so force the global state + use_global_cursor = isinstance(partition_router, GroupingPartitionRouter) + # Return the concurrent cursor and state converter return ConcurrentPerPartitionCursor( cursor_factory=cursor_factory, @@ -1366,6 +1374,7 @@ def create_concurrent_cursor_from_perpartition_cursor( connector_state_manager=state_manager, connector_state_converter=connector_state_converter, cursor_field=cursor_field, + use_global_cursor=use_global_cursor, ) @staticmethod @@ -3344,3 +3353,34 @@ def set_api_budget(self, component_definition: ComponentDefinition, config: Conf self._api_budget = self.create_component( model_type=HTTPAPIBudgetModel, component_definition=component_definition, config=config ) + + def create_grouping_partition_router( + self, model: GroupingPartitionRouterModel, config: Config, **kwargs: Any + ) -> GroupingPartitionRouter: + underlying_router = self._create_component_from_model( + model=model.underlying_partition_router, config=config + ) + if model.group_size < 1: + raise ValueError(f"Group size must be greater 
than 0, got {model.group_size}") + + # Request options in underlying partition routers are not supported for GroupingPartitionRouter + # because they are specific to individual partitions and cannot be aggregated or handled + # when grouping, potentially leading to incorrect API calls. Any request customization + # should be managed at the stream level through the requester's configuration. + if isinstance(underlying_router, SubstreamPartitionRouter): + if any( + parent_config.request_option + for parent_config in underlying_router.parent_stream_configs + ): + raise ValueError("Request options are not supported for GroupingPartitionRouter.") + + if isinstance(underlying_router, ListPartitionRouter): + if underlying_router.request_option: + raise ValueError("Request options are not supported for GroupingPartitionRouter.") + + return GroupingPartitionRouter( + group_size=model.group_size, + underlying_partition_router=underlying_router, + deduplicate=model.deduplicate if model.deduplicate is not None else True, + config=config, + ) diff --git a/airbyte_cdk/sources/declarative/partition_routers/__init__.py b/airbyte_cdk/sources/declarative/partition_routers/__init__.py index f35647402..2e99286d2 100644 --- a/airbyte_cdk/sources/declarative/partition_routers/__init__.py +++ b/airbyte_cdk/sources/declarative/partition_routers/__init__.py @@ -8,6 +8,9 @@ from airbyte_cdk.sources.declarative.partition_routers.cartesian_product_stream_slicer import ( CartesianProductStreamSlicer, ) +from airbyte_cdk.sources.declarative.partition_routers.grouping_partition_router import ( + GroupingPartitionRouter, +) from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ( ListPartitionRouter, ) @@ -22,6 +25,7 @@ __all__ = [ "AsyncJobPartitionRouter", "CartesianProductStreamSlicer", + "GroupingPartitionRouter", "ListPartitionRouter", "SinglePartitionRouter", "SubstreamPartitionRouter", diff --git 
a/airbyte_cdk/sources/declarative/partition_routers/grouping_partition_router.py b/airbyte_cdk/sources/declarative/partition_routers/grouping_partition_router.py new file mode 100644 index 000000000..a08acbbea --- /dev/null +++ b/airbyte_cdk/sources/declarative/partition_routers/grouping_partition_router.py @@ -0,0 +1,150 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from dataclasses import dataclass +from typing import Any, Iterable, Mapping, Optional + +from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter +from airbyte_cdk.sources.types import Config, StreamSlice, StreamState + + +@dataclass +class GroupingPartitionRouter(PartitionRouter): +    """ +    A partition router that groups partitions from an underlying partition router into batches of a specified size. +    This is useful for APIs that support filtering by multiple partition keys in a single request. + +    Attributes: +        group_size (int): The number of partitions to include in each group. +        underlying_partition_router (PartitionRouter): The partition router whose output will be grouped. +        deduplicate (bool): If True, ensures unique partitions across the stream by removing duplicates based on the partition key. +        config (Config): The connector configuration. +        parameters (Mapping[str, Any]): Additional parameters for interpolation and configuration. +    """ + +    group_size: int +    underlying_partition_router: PartitionRouter +    config: Config +    deduplicate: bool = True + +    def __post_init__(self) -> None: +        self._state: Optional[Mapping[str, StreamState]] = {} + +    def stream_slices(self) -> Iterable[StreamSlice]: +        """ +        Lazily groups partitions from the underlying partition router into batches of size `group_size`. + +        This method processes partitions one at a time from the underlying router, maintaining a batch buffer. +        When the buffer reaches `group_size` or the underlying router is exhausted, it yields a grouped slice. 
+        If deduplication is enabled, it tracks seen partition keys across the entire stream (the set is never reset between batches) to skip duplicate partitions. + +        Yields: +            Iterable[StreamSlice]: An iterable of StreamSlice objects, where each slice contains a batch of partition values. +        """ +        batch = [] +        seen_keys = set() + +        # Iterate over partitions lazily from the underlying router +        for partition in self.underlying_partition_router.stream_slices(): +            # Extract the partition key (assuming single key-value pair, e.g., {"board_ids": value}) +            partition_keys = list(partition.partition.keys()) +            # skip parent_slice as it is part of SubstreamPartitionRouter partition +            if "parent_slice" in partition_keys: +                partition_keys.remove("parent_slice") +            if len(partition_keys) != 1: +                raise ValueError( +                    f"GroupingPartitionRouter expects a single partition key-value pair. Got {partition.partition}" +                ) +            key = partition.partition[partition_keys[0]] + +            # Skip duplicates if deduplication is enabled +            if self.deduplicate and key in seen_keys: +                continue + +            # Add partition to the batch +            batch.append(partition) +            if self.deduplicate: +                seen_keys.add(key) + +            # Yield the batch when it reaches the group_size +            if len(batch) == self.group_size: +                self._state = self.underlying_partition_router.get_stream_state() +                yield self._create_grouped_slice(batch) +                batch = []  # Reset the batch + +        self._state = self.underlying_partition_router.get_stream_state() +        # Yield any remaining partitions if the batch isn't empty +        if batch: +            yield self._create_grouped_slice(batch) + +    def _create_grouped_slice(self, batch: list[StreamSlice]) -> StreamSlice: +        """ +        Creates a grouped StreamSlice from a batch of partitions, aggregating extra fields into a dictionary with list values. + +        Args: +            batch (list[StreamSlice]): A list of StreamSlice objects to group. + +        Returns: +            StreamSlice: A single StreamSlice with combined partition and extra field values. 
+ """ + # Combine partition values into a single dict with lists + grouped_partition = { + key: [p.partition.get(key) for p in batch] for key in batch[0].partition.keys() + } + + # Aggregate extra fields into a dict with list values + extra_fields_dict = ( + { + key: [p.extra_fields.get(key) for p in batch] + for key in set().union(*(p.extra_fields.keys() for p in batch if p.extra_fields)) + } + if any(p.extra_fields for p in batch) + else {} + ) + return StreamSlice( + partition=grouped_partition, + cursor_slice={}, # Cursor is managed by the underlying router or incremental sync + extra_fields=extra_fields_dict, + ) + + def get_request_params( + self, + stream_state: Optional[StreamState] = None, + stream_slice: Optional[StreamSlice] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + return {} + + def get_request_headers( + self, + stream_state: Optional[StreamState] = None, + stream_slice: Optional[StreamSlice] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + return {} + + def get_request_body_data( + self, + stream_state: Optional[StreamState] = None, + stream_slice: Optional[StreamSlice] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + return {} + + def get_request_body_json( + self, + stream_state: Optional[StreamState] = None, + stream_slice: Optional[StreamSlice] = None, + next_page_token: Optional[Mapping[str, Any]] = None, + ) -> Mapping[str, Any]: + return {} + + def set_initial_state(self, stream_state: StreamState) -> None: + """Delegate state initialization to the underlying partition router.""" + self.underlying_partition_router.set_initial_state(stream_state) + self._state = self.underlying_partition_router.get_stream_state() + + def get_stream_state(self) -> Optional[Mapping[str, StreamState]]: + """Delegate state retrieval to the underlying partition router.""" + return self._state diff --git 
a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index b1a9cad2c..9d462f330 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -65,6 +65,9 @@ from airbyte_cdk.sources.declarative.models import DatetimeBasedCursor as DatetimeBasedCursorModel from airbyte_cdk.sources.declarative.models import DeclarativeStream as DeclarativeStreamModel from airbyte_cdk.sources.declarative.models import DefaultPaginator as DefaultPaginatorModel +from airbyte_cdk.sources.declarative.models import ( + GroupingPartitionRouter as GroupingPartitionRouterModel, +) from airbyte_cdk.sources.declarative.models import HttpRequester as HttpRequesterModel from airbyte_cdk.sources.declarative.models import JwtAuthenticator as JwtAuthenticatorModel from airbyte_cdk.sources.declarative.models import ListPartitionRouter as ListPartitionRouterModel @@ -96,6 +99,7 @@ from airbyte_cdk.sources.declarative.partition_routers import ( AsyncJobPartitionRouter, CartesianProductStreamSlicer, + GroupingPartitionRouter, ListPartitionRouter, SinglePartitionRouter, SubstreamPartitionRouter, @@ -3840,3 +3844,156 @@ def test_api_budget_fixed_window_policy(): assert matcher._method == "GET" assert matcher._url_base == "https://example.org" assert matcher._url_path_pattern.pattern == "/v2/data" + + +def test_create_grouping_partition_router_with_underlying_router(): + content = """ + schema_loader: + file_path: "./source_example/schemas/{{ parameters['name'] }}.yaml" + name: "{{ parameters['stream_name'] }}" + retriever: + requester: + type: "HttpRequester" + path: "example" + record_selector: + extractor: + field_path: [] + stream_A: + type: DeclarativeStream + name: "A" + primary_key: "id" + $parameters: + retriever: "#/retriever" + url_base: "https://airbyte.io" + schema_loader: "#/schema_loader" 
+ sub_partition_router: + type: SubstreamPartitionRouter + parent_stream_configs: + - stream: "#/stream_A" + parent_key: id + partition_field: repository_id + partition_router: + type: GroupingPartitionRouter + underlying_partition_router: "#/sub_partition_router" + group_size: 2 + """ + parsed_manifest = YamlDeclarativeSource._parse(content) + resolved_manifest = resolver.preprocess_manifest(parsed_manifest) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) + + partition_router = factory.create_component( + model_type=GroupingPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, + ) + + # Test the created partition router + assert isinstance(partition_router, GroupingPartitionRouter) + assert isinstance(partition_router.underlying_partition_router, SubstreamPartitionRouter) + assert partition_router.group_size == 2 + + # Test the underlying partition router + parent_stream_configs = partition_router.underlying_partition_router.parent_stream_configs + assert len(parent_stream_configs) == 1 + assert isinstance(parent_stream_configs[0].stream, DeclarativeStream) + assert parent_stream_configs[0].parent_key.eval({}) == "id" + assert parent_stream_configs[0].partition_field.eval({}) == "repository_id" + + +def test_create_grouping_partition_router_invalid_group_size(): + """Test that an invalid group_size (< 1) raises a ValueError.""" + content = """ + schema_loader: + file_path: "./source_example/schemas/{{ parameters['name'] }}.yaml" + name: "{{ parameters['stream_name'] }}" + retriever: + requester: + type: "HttpRequester" + path: "example" + record_selector: + extractor: + field_path: [] + stream_A: + type: DeclarativeStream + name: "A" + primary_key: "id" + $parameters: + retriever: "#/retriever" + url_base: "https://airbyte.io" + schema_loader: "#/schema_loader" + sub_partition_router: + type: SubstreamPartitionRouter + parent_stream_configs: + - 
stream: "#/stream_A" + parent_key: id + partition_field: repository_id + partition_router: + type: GroupingPartitionRouter + underlying_partition_router: "#/sub_partition_router" + group_size: 0 + """ + parsed_manifest = YamlDeclarativeSource._parse(content) + resolved_manifest = resolver.preprocess_manifest(parsed_manifest) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) + + with pytest.raises(ValueError, match="Group size must be greater than 0, got 0"): + factory.create_component( + model_type=GroupingPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, + ) + + +def test_create_grouping_partition_router_substream_with_request_option(): + """Test that a SubstreamPartitionRouter with request_option raises a ValueError.""" + content = """ + schema_loader: + file_path: "./source_example/schemas/{{ parameters['name'] }}.yaml" + name: "{{ parameters['stream_name'] }}" + retriever: + requester: + type: "HttpRequester" + path: "example" + record_selector: + extractor: + field_path: [] + stream_A: + type: DeclarativeStream + name: "A" + primary_key: "id" + $parameters: + retriever: "#/retriever" + url_base: "https://airbyte.io" + schema_loader: "#/schema_loader" + sub_partition_router: + type: SubstreamPartitionRouter + parent_stream_configs: + - stream: "#/stream_A" + parent_key: id + partition_field: repository_id + request_option: + inject_into: request_parameter + field_name: "repo_id" + partition_router: + type: GroupingPartitionRouter + underlying_partition_router: "#/sub_partition_router" + group_size: 2 + """ + parsed_manifest = YamlDeclarativeSource._parse(content) + resolved_manifest = resolver.preprocess_manifest(parsed_manifest) + partition_router_manifest = transformer.propagate_types_and_parameters( + "", resolved_manifest["partition_router"], {} + ) + + with pytest.raises( + ValueError, match="Request options are not supported for 
GroupingPartitionRouter." + ): + factory.create_component( + model_type=GroupingPartitionRouterModel, + component_definition=partition_router_manifest, + config=input_config, + ) diff --git a/unit_tests/sources/declarative/partition_routers/test_grouping_partition_router.py b/unit_tests/sources/declarative/partition_routers/test_grouping_partition_router.py new file mode 100644 index 000000000..a75a48966 --- /dev/null +++ b/unit_tests/sources/declarative/partition_routers/test_grouping_partition_router.py @@ -0,0 +1,514 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from typing import Any, Iterable, List, Mapping, Optional, Union +from unittest.mock import MagicMock + +import pytest + +from airbyte_cdk.sources.declarative.partition_routers import ( + GroupingPartitionRouter, + SubstreamPartitionRouter, +) +from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ( + ParentStreamConfig, +) +from airbyte_cdk.sources.types import StreamSlice +from unit_tests.sources.declarative.partition_routers.test_substream_partition_router import ( + MockStream, + parent_slices, +) # Reuse MockStream and parent_slices + + +@pytest.fixture +def mock_config(): + return {} + + +@pytest.fixture +def mock_underlying_router(mock_config): + """Fixture for a simple underlying router with predefined slices and extra fields.""" + parent_stream = MockStream( + slices=[{}], # Single empty slice, parent_partition will be {} + records=[ + {"board_id": 0, "name": "Board 0", "owner": "User0"}, + { + "board_id": 0, + "name": "Board 0 Duplicate", + "owner": "User0 Duplicate", + }, # Duplicate board_id + ] + + [{"board_id": i, "name": f"Board {i}", "owner": f"User{i}"} for i in range(1, 5)], + name="mock_parent", + ) + return SubstreamPartitionRouter( + parent_stream_configs=[ + ParentStreamConfig( + stream=parent_stream, + parent_key="board_id", + partition_field="board_ids", + config=mock_config, + parameters={}, + extra_fields=[["name"], 
["owner"]], + ) + ], + config=mock_config, + parameters={}, + ) + + +@pytest.fixture +def mock_underlying_router_with_parent_slices(mock_config): + """Fixture with varied parent slices for testing non-empty parent_slice.""" + parent_stream = MockStream( + slices=parent_slices, # [{"slice": "first"}, {"slice": "second"}, {"slice": "third"}] + records=[ + {"board_id": 1, "name": "Board 1", "owner": "User1", "slice": "first"}, + {"board_id": 2, "name": "Board 2", "owner": "User2", "slice": "second"}, + {"board_id": 3, "name": "Board 3", "owner": "User3", "slice": "third"}, + ], + name="mock_parent", + ) + return SubstreamPartitionRouter( + parent_stream_configs=[ + ParentStreamConfig( + stream=parent_stream, + parent_key="board_id", + partition_field="board_ids", + config=mock_config, + parameters={}, + extra_fields=[["name"], ["owner"]], + ) + ], + config=mock_config, + parameters={}, + ) + + +@pytest.mark.parametrize( + "group_size, deduplicate, expected_slices", + [ + ( + 2, + True, + [ + StreamSlice( + partition={"board_ids": [0, 1], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 0", "Board 1"], "owner": ["User0", "User1"]}, + ), + StreamSlice( + partition={"board_ids": [2, 3], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 2", "Board 3"], "owner": ["User2", "User3"]}, + ), + StreamSlice( + partition={"board_ids": [4], "parent_slice": [{}]}, + cursor_slice={}, + extra_fields={"name": ["Board 4"], "owner": ["User4"]}, + ), + ], + ), + ( + 3, + True, + [ + StreamSlice( + partition={"board_ids": [0, 1, 2], "parent_slice": [{}, {}, {}]}, + cursor_slice={}, + extra_fields={ + "name": ["Board 0", "Board 1", "Board 2"], + "owner": ["User0", "User1", "User2"], + }, + ), + StreamSlice( + partition={"board_ids": [3, 4], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 3", "Board 4"], "owner": ["User3", "User4"]}, + ), + ], + ), + ( + 2, + False, + [ + StreamSlice( + 
partition={"board_ids": [0, 0], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={ + "name": ["Board 0", "Board 0 Duplicate"], + "owner": ["User0", "User0 Duplicate"], + }, + ), + StreamSlice( + partition={"board_ids": [1, 2], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 1", "Board 2"], "owner": ["User1", "User2"]}, + ), + StreamSlice( + partition={"board_ids": [3, 4], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 3", "Board 4"], "owner": ["User3", "User4"]}, + ), + ], + ), + ], + ids=["group_size_2_deduplicate", "group_size_3_deduplicate", "group_size_2_no_deduplicate"], +) +def test_stream_slices_grouping( + mock_config, mock_underlying_router, group_size, deduplicate, expected_slices +): + """Test basic grouping behavior with different group sizes and deduplication settings.""" + router = GroupingPartitionRouter( + group_size=group_size, + underlying_partition_router=mock_underlying_router, + deduplicate=deduplicate, + config=mock_config, + ) + slices = list(router.stream_slices()) + assert slices == expected_slices + + +def test_stream_slices_empty_underlying_router(mock_config): + """Test behavior when the underlying router yields no slices.""" + parent_stream = MockStream( + slices=[{}], + records=[], + name="mock_parent", + ) + underlying_router = SubstreamPartitionRouter( + parent_stream_configs=[ + ParentStreamConfig( + stream=parent_stream, + parent_key="board_id", + partition_field="board_ids", + config=mock_config, + parameters={}, + extra_fields=[["name"]], + ) + ], + config=mock_config, + parameters={}, + ) + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=underlying_router, + config=mock_config, + ) + slices = list(router.stream_slices()) + assert slices == [] + + +def test_stream_slices_lazy_iteration(mock_config, mock_underlying_router): + """Test that stream_slices processes partitions lazily, iterating the underlying router as an 
iterator.""" + + # Custom iterator to track yields and simulate underlying stream_slices + class ControlledIterator: + def __init__(self): + self.slices = [ + StreamSlice( + partition={"board_ids": 0}, + cursor_slice={}, + extra_fields={"name": "Board 0", "owner": "User0"}, + ), + StreamSlice( + partition={"board_ids": 1}, + cursor_slice={}, + extra_fields={"name": "Board 1", "owner": "User1"}, + ), + StreamSlice( + partition={"board_ids": 2}, + cursor_slice={}, + extra_fields={"name": "Board 2", "owner": "User2"}, + ), + StreamSlice( + partition={"board_ids": 3}, + cursor_slice={}, + extra_fields={"name": "Board 3", "owner": "User3"}, + ), + StreamSlice( + partition={"board_ids": 4}, + cursor_slice={}, + extra_fields={"name": "Board 4", "owner": "User4"}, + ), + ] + self.index = 0 + self.yield_count = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.index < len(self.slices): + self.yield_count += 1 + slice = self.slices[self.index] + self.index += 1 + return slice + raise StopIteration + + # Replace the underlying router's stream_slices with the controlled iterator + controlled_iter = ControlledIterator() + mock_underlying_router.stream_slices = MagicMock(return_value=controlled_iter) + + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=mock_underlying_router, + config=mock_config, + deduplicate=True, + ) + slices_iter = router.stream_slices() + + # Before iteration, no slices should be yielded + assert controlled_iter.yield_count == 0, "No slices should be yielded before iteration starts" + + # Get the first slice + first_slice = next(slices_iter) + assert first_slice == StreamSlice( + partition={"board_ids": [0, 1]}, + cursor_slice={}, + extra_fields={"name": ["Board 0", "Board 1"], "owner": ["User0", "User1"]}, + ) + assert ( + controlled_iter.yield_count == 2 + ), "Only 2 slices should be yielded to form the first group" + + # Get the second slice + second_slice = next(slices_iter) + assert second_slice 
== StreamSlice( + partition={"board_ids": [2, 3]}, + cursor_slice={}, + extra_fields={"name": ["Board 2", "Board 3"], "owner": ["User2", "User3"]}, + ) + assert ( + controlled_iter.yield_count == 4 + ), "Only 4 slices should be yielded up to the second group" + + # Exhaust the iterator + remaining_slices = list(slices_iter) + assert remaining_slices == [ + StreamSlice( + partition={"board_ids": [4]}, + cursor_slice={}, + extra_fields={"name": ["Board 4"], "owner": ["User4"]}, + ) + ] + assert ( + controlled_iter.yield_count == 5 + ), "All 5 slices should be yielded after exhausting the iterator" + + +def test_set_initial_state_delegation(mock_config, mock_underlying_router): + """Test that set_initial_state delegates to the underlying router.""" + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=mock_underlying_router, + config=mock_config, + ) + mock_state = {"some_key": "some_value"} + mock_underlying_router.set_initial_state = MagicMock() + + router.set_initial_state(mock_state) + mock_underlying_router.set_initial_state.assert_called_once_with(mock_state) + + +def test_stream_slices_extra_fields_varied(mock_config): + """Test grouping with varied extra fields across partitions.""" + parent_stream = MockStream( + slices=[{}], + records=[ + {"board_id": 1, "name": "Board 1", "owner": "User1"}, + {"board_id": 2, "name": "Board 2"}, # Missing owner + {"board_id": 3, "owner": "User3"}, # Missing name + ], + name="mock_parent", + ) + underlying_router = SubstreamPartitionRouter( + parent_stream_configs=[ + ParentStreamConfig( + stream=parent_stream, + parent_key="board_id", + partition_field="board_ids", + config=mock_config, + parameters={}, + extra_fields=[["name"], ["owner"]], + ) + ], + config=mock_config, + parameters={}, + ) + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=underlying_router, + config=mock_config, + deduplicate=True, + ) + expected_slices = [ + StreamSlice( + partition={"board_ids": 
[1, 2], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"name": ["Board 1", "Board 2"], "owner": ["User1", None]}, + ), + StreamSlice( + partition={"board_ids": [3], "parent_slice": [{}]}, + cursor_slice={}, + extra_fields={"name": [None], "owner": ["User3"]}, + ), + ] + slices = list(router.stream_slices()) + assert slices == expected_slices + + +def test_grouping_with_complex_partitions_and_extra_fields(mock_config): + """Test grouping with partitions containing multiple keys and extra fields.""" + parent_stream = MockStream( + slices=[{}], + records=[{"board_id": i, "extra": f"extra_{i}", "name": f"Board {i}"} for i in range(3)], + name="mock_parent", + ) + underlying_router = SubstreamPartitionRouter( + parent_stream_configs=[ + ParentStreamConfig( + stream=parent_stream, + parent_key="board_id", + partition_field="board_ids", + config=mock_config, + parameters={}, + extra_fields=[["extra"], ["name"]], + ) + ], + config=mock_config, + parameters={}, + ) + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=underlying_router, + config=mock_config, + ) + expected_slices = [ + StreamSlice( + partition={"board_ids": [0, 1], "parent_slice": [{}, {}]}, + cursor_slice={}, + extra_fields={"extra": ["extra_0", "extra_1"], "name": ["Board 0", "Board 1"]}, + ), + StreamSlice( + partition={"board_ids": [2], "parent_slice": [{}]}, + cursor_slice={}, + extra_fields={"extra": ["extra_2"], "name": ["Board 2"]}, + ), + ] + slices = list(router.stream_slices()) + assert slices == expected_slices + + +def test_stream_slices_with_non_empty_parent_slice( + mock_config, mock_underlying_router_with_parent_slices +): + """Test grouping with non-empty parent_slice values from the underlying router.""" + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=mock_underlying_router_with_parent_slices, + config=mock_config, + deduplicate=True, + ) + expected_slices = [ + StreamSlice( + partition={ + "board_ids": [1, 2], + 
"parent_slice": [{"slice": "first"}, {"slice": "second"}], + }, + cursor_slice={}, + extra_fields={"name": ["Board 1", "Board 2"], "owner": ["User1", "User2"]}, + ), + StreamSlice( + partition={"board_ids": [3], "parent_slice": [{"slice": "third"}]}, + cursor_slice={}, + extra_fields={"name": ["Board 3"], "owner": ["User3"]}, + ), + ] + slices = list(router.stream_slices()) + assert slices == expected_slices + + +def test_get_request_params_default(mock_config, mock_underlying_router): + """Test that get_request_params returns an empty dict by default.""" + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=mock_underlying_router, + config=mock_config, + ) + params = router.get_request_params( + stream_slice=StreamSlice( + partition={"board_ids": [1, 2], "parent_slice": [{}, {}]}, cursor_slice={} + ) + ) + assert params == {} + + +def test_stream_slices_resume_from_state(mock_config, mock_underlying_router): + """Test that stream_slices resumes correctly from a previous state.""" + + # Simulate underlying router state handling + class MockPartitionRouter: + def __init__(self): + self.slices = [ + StreamSlice( + partition={"board_ids": i}, + cursor_slice={}, + extra_fields={"name": f"Board {i}", "owner": f"User{i}"}, + ) + for i in range(5) + ] + self.state = {"last_board_id": 0} # Initial state + + def set_initial_state(self, state): + self.state = state + + def get_stream_state(self): + return self.state + + def stream_slices(self): + last_board_id = self.state.get("last_board_id", -1) + for slice in self.slices: + board_id = slice.partition["board_ids"] + if board_id <= last_board_id: + continue + self.state = {"last_board_id": board_id} + yield slice + + underlying_router = MockPartitionRouter() + router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=underlying_router, + config=mock_config, + deduplicate=True, + ) + + # First sync: process first two slices + router.set_initial_state({"last_board_id": 0}) + 
slices_iter = router.stream_slices() + first_batch = next(slices_iter) + assert first_batch == StreamSlice( + partition={"board_ids": [1, 2]}, + cursor_slice={}, + extra_fields={"name": ["Board 1", "Board 2"], "owner": ["User1", "User2"]}, + ) + state_after_first = router.get_stream_state() + assert state_after_first == {"last_board_id": 2}, "State should reflect last processed board_id" + + # Simulate a new sync resuming from the previous state + new_router = GroupingPartitionRouter( + group_size=2, + underlying_partition_router=MockPartitionRouter(), + config=mock_config, + deduplicate=True, + ) + new_router.set_initial_state(state_after_first) + resumed_slices = list(new_router.stream_slices()) + assert resumed_slices == [ + StreamSlice( + partition={"board_ids": [3, 4]}, + cursor_slice={}, + extra_fields={"name": ["Board 3", "Board 4"], "owner": ["User3", "User4"]}, + ) + ], "Should resume from board_id 3"