diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index e8bca6cc9..8e77c3409 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -320,14 +320,14 @@ def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str return polling_response_context def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice: - stream_slice = StreamSlice( - partition={}, - cursor_slice={}, - extra_fields={ + return StreamSlice( + partition=job.job_parameters().partition, + cursor_slice=job.job_parameters().cursor_slice, + extra_fields=dict(job.job_parameters().extra_fields) + | { "creation_response": self._get_creation_response_interpolation_context(job), }, ) - return stream_slice def _get_download_targets(self, job: AsyncJob) -> Iterable[str]: if not self.download_target_requester: diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py index 4e175bb28..a363f9edc 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py @@ -11,6 +11,7 @@ ) from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.types import Config, StreamSlice +from airbyte_cdk.utils.mapping_helpers import get_interpolation_context @dataclass @@ -52,8 +53,8 @@ def eval_request_inputs( :param next_page_token: The pagination token :return: The request inputs to set on an outgoing HTTP request """ - kwargs = { - "stream_slice": stream_slice, - "next_page_token": next_page_token, - } + kwargs = get_interpolation_context( + stream_slice=stream_slice, + next_page_token=next_page_token, + ) return self._interpolator.eval(self.config, **kwargs) # type: ignore # self._interpolator is always initialized with a value and will not be None diff --git a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py index ed0e54c60..dfe8d6460 100644 --- a/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py +++ b/airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py @@ -8,6 +8,7 @@ from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.types import Config, StreamSlice, StreamState +from airbyte_cdk.utils.mapping_helpers import get_interpolation_context @dataclass @@ -51,10 +52,10 @@ def eval_request_inputs( :param valid_value_types: A tuple of types that the interpolator should allow :return: The request inputs to set on an outgoing HTTP request """ - kwargs = { - "stream_slice": stream_slice, - "next_page_token": next_page_token, - } + kwargs = get_interpolation_context( + stream_slice=stream_slice, + next_page_token=next_page_token, + ) interpolated_value = self._interpolator.eval( # type: ignore # self._interpolator is always initialized with a value and will not be None self.config, valid_key_types=valid_key_types, diff --git a/unit_tests/sources/declarative/requesters/test_http_job_repository.py b/unit_tests/sources/declarative/requesters/test_http_job_repository.py index 4be3ecb11..473c3d99e 100644 --- a/unit_tests/sources/declarative/requesters/test_http_job_repository.py +++ b/unit_tests/sources/declarative/requesters/test_http_job_repository.py @@ -2,6 +2,7 @@ import json +from typing import Optional from unittest import TestCase from unittest.mock import Mock @@ -28,6 +29,8 @@ ) from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse @@ -45,111 +48,12 @@ a_record_id,a_value """ _A_CURSOR_FOR_PAGINATION = "a-cursor-for-pagination" +_ERROR_HANDLER = DefaultErrorHandler(config=_ANY_CONFIG, parameters={}) class HttpJobRepositoryTest(TestCase): def setUp(self) -> None: - message_repository = Mock() - error_handler = DefaultErrorHandler(config=_ANY_CONFIG, parameters={}) - - self._create_job_requester = HttpRequester( - name="stream : create_job", - url_base=_URL_BASE, - path=_EXPORT_PATH, - error_handler=error_handler, - http_method=HttpMethod.POST, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=False, - ) - - self._polling_job_requester = HttpRequester( - name="stream : polling", - url_base=_URL_BASE, - path=_EXPORT_PATH + "/{{creation_response['id']}}", - error_handler=error_handler, - http_method=HttpMethod.GET, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=False, - ) - - self._download_retriever = SimpleRetriever( - requester=HttpRequester( - name="stream : fetch_result", - url_base="", - path="{{download_target}}", - error_handler=error_handler, - http_method=HttpMethod.GET, - config=_ANY_CONFIG, - disable_retries=False, - parameters={}, - message_repository=message_repository, - use_cache=False, - stream_response=True, - ), - record_selector=RecordSelector( - extractor=ResponseToFileExtractor({}), - record_filter=None, - transformations=[], - schema_normalization=TypeTransformer(TransformConfig.NoTransform), - config=_ANY_CONFIG, - parameters={}, - ), - primary_key=None, - name="any name", - paginator=DefaultPaginator( - decoder=NoopDecoder(), - page_size_option=None, - page_token_option=RequestOption( - field_name="locator", - inject_into=RequestOptionType.request_parameter, - parameters={}, - ), - pagination_strategy=CursorPaginationStrategy( - cursor_value="{{ headers['Sforce-Locator'] }}", - decoder=NoopDecoder(), - config=_ANY_CONFIG, - parameters={}, - ), - url_base=_URL_BASE, - config=_ANY_CONFIG, - parameters={}, - ), - config=_ANY_CONFIG, - parameters={}, - ) - - self._repository = AsyncHttpJobRepository( - creation_requester=self._create_job_requester, - polling_requester=self._polling_job_requester, - download_retriever=self._download_retriever, - abort_requester=None, - delete_requester=None, - status_extractor=DpathExtractor( - decoder=JsonDecoder(parameters={}), - field_path=["status"], - config={}, - parameters={} or {}, - ), - status_mapping={ - "ready": AsyncJobStatus.COMPLETED, - "failure": AsyncJobStatus.FAILED, - "pending": AsyncJobStatus.RUNNING, - }, - download_target_extractor=DpathExtractor( - decoder=JsonDecoder(parameters={}), - field_path=["urls"], - config={}, - parameters={} or {}, - ), - ) + self._repository = self._create_async_job_repository() self._http_mocker = HttpMocker() self._http_mocker.__enter__() @@ -178,6 +82,35 @@ def test_given_different_statuses_when_update_jobs_status_then_update_status_pro self._repository.update_jobs_status([job]) assert job.status() == AsyncJobStatus.COMPLETED + def test_when_update_jobs_status_then_allow_access_to_stream_slice_information(self) -> None: + stream_slice = StreamSlice(partition={"path": "path_from_slice"}, cursor_slice={}) + self._mock_create_response(_A_JOB_ID) + self._http_mocker.get( + HttpRequest(url=f"{_EXPORT_URL}/{stream_slice['path']}/{_A_JOB_ID}"), + HttpResponse(body=json.dumps({"id": _A_JOB_ID, "status": "ready"})), + ) + repository = self._create_async_job_repository( + HttpRequester( + name="stream : polling", + url_base=_URL_BASE, + path=_EXPORT_PATH + "/{{stream_slice['path']}}/{{creation_response['id']}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=Mock(), + # this might not align with the rest of the components in async job repository but if message_repository becomes important for tests, please share this instance with the other components + use_cache=False, + stream_response=False, + ) + ) + + job = repository.start(stream_slice) + repository.update_jobs_status([job]) + + assert job.status() == AsyncJobStatus.COMPLETED + def test_given_unknown_status_when_update_jobs_status_then_raise_error(self) -> None: self._mock_create_response(_A_JOB_ID) self._http_mocker.get( @@ -277,3 +210,109 @@ def _mock_create_response(self, job_id: str) -> None: HttpRequest(url=_EXPORT_URL), HttpResponse(body=json.dumps({"id": job_id})), ) + + def _create_async_job_repository( + self, polling_job_requester: Optional[HttpRequester] = None + ) -> AsyncHttpJobRepository: + message_repository = Mock() + create_job_requester = HttpRequester( + name="stream : create_job", + url_base=_URL_BASE, + path=_EXPORT_PATH, + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.POST, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=False, + ) + polling_job_requester = ( + polling_job_requester + if polling_job_requester + else HttpRequester( + name="stream : polling", + url_base=_URL_BASE, + path=_EXPORT_PATH + "/{{creation_response['id']}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=False, + ) + ) + + download_retriever = SimpleRetriever( + requester=HttpRequester( + name="stream : fetch_result", + url_base="", + path="{{download_target}}", + error_handler=_ERROR_HANDLER, + http_method=HttpMethod.GET, + config=_ANY_CONFIG, + disable_retries=False, + parameters={}, + message_repository=message_repository, + use_cache=False, + stream_response=True, + ), + record_selector=RecordSelector( + extractor=ResponseToFileExtractor({}), + record_filter=None, + transformations=[], + schema_normalization=TypeTransformer(TransformConfig.NoTransform), + config=_ANY_CONFIG, + parameters={}, + ), + primary_key=None, + name="any name", + paginator=DefaultPaginator( + decoder=NoopDecoder(), + page_size_option=None, + page_token_option=RequestOption( + field_name="locator", + inject_into=RequestOptionType.request_parameter, + parameters={}, + ), + pagination_strategy=CursorPaginationStrategy( + cursor_value="{{ headers['Sforce-Locator'] }}", + decoder=NoopDecoder(), + config=_ANY_CONFIG, + parameters={}, + ), + url_base=_URL_BASE, + config=_ANY_CONFIG, + parameters={}, + ), + config=_ANY_CONFIG, + parameters={}, + ) + + return AsyncHttpJobRepository( + creation_requester=create_job_requester, + polling_requester=polling_job_requester, + download_retriever=download_retriever, + abort_requester=None, + delete_requester=None, + status_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["status"], + config={}, + parameters={} or {}, + ), + status_mapping={ + "ready": AsyncJobStatus.COMPLETED, + "failure": AsyncJobStatus.FAILED, + "pending": AsyncJobStatus.RUNNING, + }, + download_target_extractor=DpathExtractor( + decoder=JsonDecoder(parameters={}), + field_path=["urls"], + config={}, + parameters={} or {}, + ), + )