diff --git a/airbyte_cdk/__init__.py b/airbyte_cdk/__init__.py index 262d162cc..41629111d 100644 --- a/airbyte_cdk/__init__.py +++ b/airbyte_cdk/__init__.py @@ -48,12 +48,14 @@ # Once those issues are resolved, the below can be sorted with isort. import dunamai as _dunamai +from airbyte_cdk.sources.abstract_source import AbstractSource +from airbyte_cdk.sources.source import Source + +# from airbyte_cdk.destinations.destination import Destination from .config_observation import ( create_connector_config_control_message, emit_configuration_as_airbyte_control_message, ) -from .connector import BaseConnector, Connector -from .destinations import Destination from .entrypoint import AirbyteEntrypoint, launch from .logger import AirbyteLogFormatter, init_logger from .models import ( @@ -75,7 +77,6 @@ SyncMode, Type, ) -from .sources import AbstractSource, Source from .sources.concurrent_source.concurrent_source import ConcurrentSource from .sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from .sources.config import BaseConfig @@ -212,7 +213,6 @@ "AbstractSource", "BaseConfig", "BaseConnector", - "Connector", "Destination", "Source", "TState", diff --git a/airbyte_cdk/cli/source_declarative_manifest/_run.py b/airbyte_cdk/cli/source_declarative_manifest/_run.py index df36b3df1..ceea51e22 100644 --- a/airbyte_cdk/cli/source_declarative_manifest/_run.py +++ b/airbyte_cdk/cli/source_declarative_manifest/_run.py @@ -45,6 +45,7 @@ ) from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource from airbyte_cdk.sources.source import TState +from airbyte_cdk.utils.cli_arg_parse import parse_cli_args from airbyte_cdk.utils.datetime_helpers import ab_datetime_now @@ -93,7 +94,7 @@ def handle_command(args: list[str]) -> None: def _get_local_yaml_source(args: list[str]) -> SourceLocalYaml: try: - parsed_args = AirbyteEntrypoint.parse_args(args) + parsed_args = parse_cli_args(args) config, catalog, state = 
_parse_inputs_into_config_catalog_state(parsed_args) return SourceLocalYaml(config=config, catalog=catalog, state=state) except Exception as error: @@ -119,10 +120,7 @@ def _get_local_yaml_source(args: list[str]) -> SourceLocalYaml: def handle_local_manifest_command(args: list[str]) -> None: source = _get_local_yaml_source(args) - launch( - source=source, - args=args, - ) + source.launch_with_cli_args(args) def handle_remote_manifest_command(args: list[str]) -> None: @@ -149,10 +147,7 @@ def handle_remote_manifest_command(args: list[str]) -> None: print(AirbyteEntrypoint.airbyte_message_to_string(message)) else: source = create_declarative_source(args) - launch( - source=source, - args=args, - ) + source.launch_with_cli_args(args=args) def create_declarative_source( @@ -169,7 +164,7 @@ def create_declarative_source( catalog: ConfiguredAirbyteCatalog | None state: list[AirbyteStateMessage] - parsed_args = AirbyteEntrypoint.parse_args(args) + parsed_args = parse_cli_args(args) config, catalog, state = _parse_inputs_into_config_catalog_state(parsed_args) if config is None: diff --git a/airbyte_cdk/connector.py b/airbyte_cdk/connector.py index a7c13dcda..bfea0d78f 100644 --- a/airbyte_cdk/connector.py +++ b/airbyte_cdk/connector.py @@ -8,18 +8,20 @@ import os import pkgutil from abc import ABC, abstractmethod -from typing import Any, Generic, Mapping, MutableMapping, Optional, Protocol, TypeVar +from collections.abc import MutableMapping +from pathlib import Path +from typing import Any, Generic, Mapping, Optional, TypeVar import yaml +from typing_extensions import Self, deprecated -from airbyte_cdk.models import ( - AirbyteConnectionStatus, - ConnectorSpecification, - ConnectorSpecificationSerializer, -) +from airbyte_cdk.models import AirbyteConnectionStatus +from airbyte_cdk.models.airbyte_protocol import AirbyteMessage, ConnectorSpecification, Type +from airbyte_cdk.sources.message.repository import MessageRepository, PassthroughMessageRepository +from 
airbyte_cdk.utils.cli_arg_parse import ConnectorCLIArgs, parse_cli_args -def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: +def _load_optional_package_file(package: str, filename: str) -> Optional[bytes]: """Gets a resource from a package, returning None if it does not exist""" try: return pkgutil.get_data(package, filename) @@ -27,6 +29,32 @@ def load_optional_package_file(package: str, filename: str) -> Optional[bytes]: return None +def _write_config(config: Mapping[str, Any], config_path: str) -> None: + Path(config_path).write_text(json.dumps(config)) + + +def _read_json_file(file_path: str) -> Any: + with open(file_path, "r") as file: + contents = file.read() + + try: + return json.loads(contents) + except json.JSONDecodeError as error: + raise ValueError( + f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON." + ) + + +def _read_config(config_path: str) -> MutableMapping[str, Any]: + config = _read_json_file(config_path) + if isinstance(config, MutableMapping): + return config + else: + raise ValueError( + f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." + ) + + TConfig = TypeVar("TConfig", bound=Mapping[str, Any]) @@ -34,38 +62,31 @@ class BaseConnector(ABC, Generic[TConfig]): # configure whether the `check_config_against_spec_or_exit()` needs to be called check_config_against_spec: bool = True - @abstractmethod - def configure(self, config: Mapping[str, Any], temp_dir: str) -> TConfig: - """ - Persist config in temporary directory to run the Source job - """ + @classmethod + def to_typed_config( + cls, + config: Mapping[str, Any], + ) -> TConfig: + """Return a typed config object from a config dictionary.""" + ... 
@staticmethod def read_config(config_path: str) -> MutableMapping[str, Any]: - config = BaseConnector._read_json_file(config_path) - if isinstance(config, MutableMapping): - return config - else: - raise ValueError( - f"The content of {config_path} is not an object and therefore is not a valid config. Please ensure the file represent a config." - ) + return _read_config(config_path) @staticmethod def _read_json_file(file_path: str) -> Any: - with open(file_path, "r") as file: - contents = file.read() - - try: - return json.loads(contents) - except json.JSONDecodeError as error: - raise ValueError( - f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON." - ) + return _read_json_file(file_path) @staticmethod def write_config(config: TConfig, config_path: str) -> None: - with open(config_path, "w") as fh: - fh.write(json.dumps(config)) + _write_config(config, config_path) + + @classmethod + def configure(cls, config: Mapping[str, Any], temp_dir: str) -> TConfig: + config_path = os.path.join(temp_dir, "config.json") + _write_config(config, config_path) + return cls.to_typed_config(config) def spec(self, logger: logging.Logger) -> ConnectorSpecification: """ @@ -75,8 +96,8 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: package = self.__class__.__module__.split(".")[0] - yaml_spec = load_optional_package_file(package, "spec.yaml") - json_spec = load_optional_package_file(package, "spec.json") + yaml_spec = _load_optional_package_file(package, "spec.yaml") + json_spec = _load_optional_package_file(package, "spec.json") if yaml_spec and json_spec: raise RuntimeError( @@ -95,7 +116,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification: else: raise FileNotFoundError("Unable to find spec.yaml or spec.json in the package.") - return ConnectorSpecificationSerializer.load(spec_obj) + return ConnectorSpecification.from_dict(spec_obj) @abstractmethod def check(self, logger: logging.Logger, config: TConfig) -> 
AirbyteConnectionStatus: @@ -104,20 +125,60 @@ def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionSta to the Stripe API. """ - -class _WriteConfigProtocol(Protocol): - @staticmethod - def write_config(config: Mapping[str, Any], config_path: str) -> None: ... - - -class DefaultConnectorMixin: - # can be overridden to change an input config - def configure( - self: _WriteConfigProtocol, config: Mapping[str, Any], temp_dir: str - ) -> Mapping[str, Any]: - config_path = os.path.join(temp_dir, "config.json") - self.write_config(config, config_path) - return config - - -class Connector(DefaultConnectorMixin, BaseConnector[Mapping[str, Any]], ABC): ... + @classmethod + def create_with_cli_args( + cls, + cli_args: ConnectorCLIArgs, + ) -> Self: + """Return an instance of the connector, using the provided CLI args.""" + ... + + @classmethod + def launch_with_cli_args( + cls, + args: list[str], + *, + logger: logging.Logger | None = None, + message_repository: MessageRepository | None = None, + # TODO: Add support for inputs: + # stdin: StringIO | MessageRepository | None = None, + ) -> None: + """Launches the connector with the provided configuration.""" + logger = logger or logging.getLogger(f"airbyte.{cls.__name__}") + message_repository = message_repository or PassthroughMessageRepository() + parsed_cli_args: ConnectorCLIArgs = parse_cli_args( + args, + with_read=True if getattr(cls, "read", False) else False, + with_write=True if getattr(cls, "write", False) else False, + with_discover=True if getattr(cls, "discover", False) else False, + ) + logger.info(f"Launching connector with args: {parsed_cli_args}") + verb = parsed_cli_args.command + + spec: ConnectorSpecification + if verb == "check": + config = cls.to_typed_config(parsed_cli_args.get_config_dict()) + connector = cls.create_with_cli_args(parsed_cli_args) + message_repository.emit_message(AirbyteMessage(type=Type.CONNECTION_STATUS, connectionStatus=connector.check(logger, config))) + elif verb == "spec": + connector = cls() + spec = connector.spec(logger) + 
message_repository.emit_message( + AirbyteMessage( + type=Type.SPEC, + spec=spec, + ) + ) + elif verb == "discover": + connector = cls() + spec = connector.spec(logger) + print(json.dumps(spec.to_dict(), indent=2)) + elif verb == "read": + # Implementation for reading data goes here + pass + elif verb == "write": + # Implementation for writing data goes here + pass + else: + raise ValueError(f"Unknown command: {verb}") + # Implementation for launching the connector goes here diff --git a/airbyte_cdk/connector_builder/main.py b/airbyte_cdk/connector_builder/main.py index 80cf4afa9..69a02a31f 100644 --- a/airbyte_cdk/connector_builder/main.py +++ b/airbyte_cdk/connector_builder/main.py @@ -27,6 +27,7 @@ ) from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource from airbyte_cdk.sources.source import Source +from airbyte_cdk.utils.cli_arg_parse import parse_cli_args from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -35,7 +36,7 @@ def get_config_and_catalog_from_args( ) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]: # TODO: Add functionality for the `debug` logger. # Currently, no one `debug` level log will be displayed during `read` a stream for a connector created through `connector-builder`. - parsed_args = AirbyteEntrypoint.parse_args(args) + parsed_args = parse_cli_args(args) config_path, catalog_path, state_path = ( parsed_args.config, parsed_args.catalog, diff --git a/airbyte_cdk/connector_builder/test_reader/reader.py b/airbyte_cdk/connector_builder/test_reader/reader.py index ea6e960c2..d0eef1c06 100644 --- a/airbyte_cdk/connector_builder/test_reader/reader.py +++ b/airbyte_cdk/connector_builder/test_reader/reader.py @@ -399,9 +399,7 @@ def _read_stream( # the generator can raise an exception # iterate over the generated messages. 
if next raise an exception, catch it and yield it as an AirbyteLogMessage try: - yield from AirbyteEntrypoint(source).read( - source.spec(self.logger), config, configured_catalog, state - ) + yield from source.read(source.spec(self.logger), config, configured_catalog, state) except AirbyteTracedException as traced_exception: # Look for this message which indicates that it is the "final exception" raised by AbstractSource. # If it matches, don't yield this as we don't need to show this in the Builder. diff --git a/airbyte_cdk/destinations/__init__.py b/airbyte_cdk/destinations/__init__.py index 3a641025b..c697d89ce 100644 --- a/airbyte_cdk/destinations/__init__.py +++ b/airbyte_cdk/destinations/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) 2021 Airbyte, Inc., all rights reserved. """The destinations module provides classes for building destination connectors.""" -from .destination import Destination +from airbyte_cdk.destinations.destination import Destination __all__ = [ "Destination", diff --git a/airbyte_cdk/destinations/destination.py b/airbyte_cdk/destinations/destination.py index 547f96684..9548be9d7 100644 --- a/airbyte_cdk/destinations/destination.py +++ b/airbyte_cdk/destinations/destination.py @@ -11,7 +11,7 @@ import orjson -from airbyte_cdk.connector import Connector +from airbyte_cdk.connector import BaseConnector from airbyte_cdk.exception_handler import init_uncaught_exception_handler from airbyte_cdk.models import ( AirbyteMessage, @@ -21,12 +21,13 @@ Type, ) from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit +from airbyte_cdk.utils.cli_arg_parse import ConnectorCLIArgs, parse_cli_args from airbyte_cdk.utils.traced_exception import AirbyteTracedException logger = logging.getLogger("airbyte") -class Destination(Connector, ABC): +class Destination(BaseConnector, ABC): VALID_CMDS = {"spec", "check", "write"} @abstractmethod @@ -68,54 +69,10 @@ def _run_write( ) logger.info("Writing complete.") - def parse_args(self, 
args: List[str]) -> argparse.Namespace: - """ - :param args: commandline arguments - :return: - """ - - parent_parser = argparse.ArgumentParser(add_help=False) - main_parser = argparse.ArgumentParser() - subparsers = main_parser.add_subparsers(title="commands", dest="command") - - # spec - subparsers.add_parser( - "spec", help="outputs the json configuration specification", parents=[parent_parser] - ) - - # check - check_parser = subparsers.add_parser( - "check", help="checks the config can be used to connect", parents=[parent_parser] - ) - required_check_parser = check_parser.add_argument_group("required named arguments") - required_check_parser.add_argument( - "--config", type=str, required=True, help="path to the json configuration file" - ) - - # write - write_parser = subparsers.add_parser( - "write", help="Writes data to the destination", parents=[parent_parser] - ) - write_required = write_parser.add_argument_group("required named arguments") - write_required.add_argument( - "--config", type=str, required=True, help="path to the JSON configuration file" - ) - write_required.add_argument( - "--catalog", type=str, required=True, help="path to the configured catalog JSON file" - ) - - parsed_args = main_parser.parse_args(args) - cmd = parsed_args.command - if not cmd: - raise Exception("No command entered. 
") - elif cmd not in ["spec", "check", "write"]: - # This is technically dead code since parse_args() would fail if this was the case - # But it's non-obvious enough to warrant placing it here anyways - raise Exception(f"Unknown command entered: {cmd}") - - return parsed_args - - def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: + def run_cmd( + self, + parsed_args: ConnectorCLIArgs, + ) -> Iterable[AirbyteMessage]: cmd = parsed_args.command if cmd not in self.VALID_CMDS: raise Exception(f"Unrecognized command: {cmd}") @@ -138,6 +95,9 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: if cmd == "check": yield self._run_check(config=config) elif cmd == "write": + if not parsed_args.catalog: + raise ValueError("Catalog path is required for write command.") + # Wrap in UTF-8 to override any other input encodings wrapped_stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8") yield from self._run_write( @@ -148,7 +108,13 @@ def run_cmd(self, parsed_args: argparse.Namespace) -> Iterable[AirbyteMessage]: def run(self, args: List[str]) -> None: init_uncaught_exception_handler(logger) - parsed_args = self.parse_args(args) + parsed_args = ConnectorCLIArgs.from_namespace( + parse_cli_args( + args, + with_write=True, + with_read=False, + ) + ) output_messages = self.run_cmd(parsed_args) for message in output_messages: print(orjson.dumps(AirbyteMessageSerializer.dump(message)).decode()) diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index 76a1be32e..261e740f0 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -7,7 +7,6 @@ import ipaddress import json import logging -import os.path import socket import sys import tempfile @@ -19,10 +18,11 @@ import orjson import requests from requests import PreparedRequest, Response, Session +from typing_extensions import deprecated from airbyte_cdk.connector import TConfig from airbyte_cdk.exception_handler import 
init_uncaught_exception_handler -from airbyte_cdk.logger import PRINT_BUFFER, init_logger +from airbyte_cdk.logger import init_logger from airbyte_cdk.models import ( AirbyteConnectionStatus, AirbyteMessage, @@ -33,13 +33,14 @@ Status, Type, ) -from airbyte_cdk.sources import Source from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor +from airbyte_cdk.sources.source import Source from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config # from airbyte_cdk.utils import PrintBuffer, is_cloud_environment, message_utils # add PrintBuffer back once fixed from airbyte_cdk.utils import is_cloud_environment, message_utils from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets +from airbyte_cdk.utils.cli_arg_parse import ConnectorCLIArgs, parse_cli_args from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH from airbyte_cdk.utils.traced_exception import AirbyteTracedException @@ -50,6 +51,9 @@ _HAS_LOGGED_FOR_SERIALIZATION_ERROR = False +@deprecated( + "The `AirbyteEntrypoint` class is deprecated. Use `source.launch_with_cli_args()` instead." 
+) class AirbyteEntrypoint(object): def __init__(self, source: Source): init_uncaught_exception_handler(logger) @@ -61,99 +65,7 @@ def __init__(self, source: Source): self.source = source self.logger = logging.getLogger(f"airbyte.{getattr(source, 'name', '')}") - @staticmethod - def parse_args(args: List[str]) -> argparse.Namespace: - # set up parent parsers - parent_parser = argparse.ArgumentParser(add_help=False) - parent_parser.add_argument( - "--debug", action="store_true", help="enables detailed debug logs related to the sync" - ) - main_parser = argparse.ArgumentParser() - subparsers = main_parser.add_subparsers(title="commands", dest="command") - - # spec - subparsers.add_parser( - "spec", help="outputs the json configuration specification", parents=[parent_parser] - ) - - # check - check_parser = subparsers.add_parser( - "check", help="checks the config can be used to connect", parents=[parent_parser] - ) - required_check_parser = check_parser.add_argument_group("required named arguments") - required_check_parser.add_argument( - "--config", type=str, required=True, help="path to the json configuration file" - ) - check_parser.add_argument( - "--manifest-path", - type=str, - required=False, - help="path to the YAML manifest file to inject into the config", - ) - check_parser.add_argument( - "--components-path", - type=str, - required=False, - help="path to the custom components file, if it exists", - ) - - # discover - discover_parser = subparsers.add_parser( - "discover", - help="outputs a catalog describing the source's schema", - parents=[parent_parser], - ) - required_discover_parser = discover_parser.add_argument_group("required named arguments") - required_discover_parser.add_argument( - "--config", type=str, required=True, help="path to the json configuration file" - ) - discover_parser.add_argument( - "--manifest-path", - type=str, - required=False, - help="path to the YAML manifest file to inject into the config", - ) - discover_parser.add_argument( 
- "--components-path", - type=str, - required=False, - help="path to the custom components file, if it exists", - ) - - # read - read_parser = subparsers.add_parser( - "read", help="reads the source and outputs messages to STDOUT", parents=[parent_parser] - ) - - read_parser.add_argument( - "--state", type=str, required=False, help="path to the json-encoded state file" - ) - required_read_parser = read_parser.add_argument_group("required named arguments") - required_read_parser.add_argument( - "--config", type=str, required=True, help="path to the json configuration file" - ) - required_read_parser.add_argument( - "--catalog", - type=str, - required=True, - help="path to the catalog used to determine which data to read", - ) - read_parser.add_argument( - "--manifest-path", - type=str, - required=False, - help="path to the YAML manifest file to inject into the config", - ) - read_parser.add_argument( - "--components-path", - type=str, - required=False, - help="path to the custom components file, if it exists", - ) - - return main_parser.parse_args(args) - - def run(self, parsed_args: argparse.Namespace) -> Iterable[str]: + def run(self, parsed_args: argparse.Namespace | ConnectorCLIArgs) -> Iterable[str]: cmd = parsed_args.command if not cmd: raise Exception("No command passed") @@ -343,21 +255,21 @@ def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str: @classmethod def extract_state(cls, args: List[str]) -> Optional[Any]: - parsed_args = cls.parse_args(args) + parsed_args = parse_cli_args(args) if hasattr(parsed_args, "state"): return parsed_args.state return None @classmethod def extract_catalog(cls, args: List[str]) -> Optional[Any]: - parsed_args = cls.parse_args(args) + parsed_args = parse_cli_args(args) if hasattr(parsed_args, "catalog"): return parsed_args.catalog return None @classmethod def extract_config(cls, args: List[str]) -> Optional[Any]: - parsed_args = cls.parse_args(args) + parsed_args: ConnectorCLIArgs = parse_cli_args(args) if 
hasattr(parsed_args, "config"): return parsed_args.config return None @@ -368,16 +280,10 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]: return +@deprecated("The `launch()` method is deprecated. Use `source.launch_with_cli_args()` instead.") def launch(source: Source, args: List[str]) -> None: - source_entrypoint = AirbyteEntrypoint(source) - parsed_args = source_entrypoint.parse_args(args) - # temporarily removes the PrintBuffer because we're seeing weird print behavior for concurrent syncs - # Refer to: https://github.com/airbytehq/oncall/issues/6235 - with PRINT_BUFFER: - for message in source_entrypoint.run(parsed_args): - # simply printing is creating issues for concurrent CDK as Python uses different two instructions to print: one for the message and - # the other for the break line. Adding `\n` to the message ensure that both are printed at the same time - print(f"{message}\n", end="") + """Deprecated.""" + source.launch_with_cli_args(args) def _init_internal_request_filter() -> None: @@ -447,4 +353,4 @@ def main() -> None: if not isinstance(source, Source): raise Exception("Source implementation provided does not implement Source class!") - launch(source, sys.argv[1:]) + source.launch_with_cli_args(sys.argv[1:]) diff --git a/airbyte_cdk/models/airbyte_protocol.py b/airbyte_cdk/models/airbyte_protocol.py index 5c5624428..594ea9bc9 100644 --- a/airbyte_cdk/models/airbyte_protocol.py +++ b/airbyte_cdk/models/airbyte_protocol.py @@ -1,14 +1,71 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -# +# Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
+from __future__ import annotations + +import json from dataclasses import InitVar, dataclass -from typing import Annotated, Any, Dict, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Dict, + List, + Optional, + Type, + TypeVar, +) -from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*' +import orjson +from airbyte_protocol_dataclasses.models import ( + AdvancedAuth, + AirbyteAnalyticsTraceMessage, + AirbyteCatalog, + AirbyteConnectionStatus, + AirbyteControlConnectorConfigMessage, + AirbyteControlMessage, + AirbyteErrorTraceMessage, + AirbyteEstimateTraceMessage, + AirbyteGlobalState, + AirbyteLogMessage, + AirbyteMessage, + AirbyteProtocol, + AirbyteRecordMessage, + AirbyteRecordMessageFileReference, + AirbyteStateBlob, + AirbyteStateMessage, + AirbyteStateStats, + AirbyteStateType, + AirbyteStream, + AirbyteStreamState, + AirbyteStreamStatus, + AirbyteStreamStatusReason, + AirbyteStreamStatusReasonType, + AirbyteStreamStatusTraceMessage, + AirbyteTraceMessage, + AuthFlowType, + ConfiguredAirbyteCatalog, + ConfiguredAirbyteStream, + ConnectorSpecification, + DestinationSyncMode, + EstimateType, + FailureType, + Level, + OAuthConfigSpecification, + OauthConnectorInputSpecification, + OrchestratorType, + State, + Status, + StreamDescriptor, + SyncMode, + TraceType, + Type, +) +from boltons.typeutils import classproperty +from serpyco_rs import CustomType, Serializer from serpyco_rs.metadata import Alias -# ruff: noqa: F405 # ignore fuzzy import issues with 'import *' +if TYPE_CHECKING: + from collections.abc import Callable, Mapping @dataclass @@ -50,22 +107,113 @@ def __eq__(self, other: object) -> bool: ) +T = TypeVar("T", bound="SerDeMixin") + + +def _custom_state_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None: + class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]): + def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]: + # cant use 
orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete" + return {k: v for k, v in value.__dict__.items()} + + def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob: + return AirbyteStateBlob(value) + + def get_json_schema(self) -> dict[str, Any]: + return {"type": "object"} + + return AirbyteStateBlobType() if t is AirbyteStateBlob else None + + +class SerDeMixin: + _serializer: Serializer[Any] + + def to_dict(self) -> dict[str, Any]: + """Serialize the object to a dictionary. + + This method uses the `Serializer` to serialize the object to a dict as quickly as possible. + """ + return self._serializer.dump(self) + + def to_json(self) -> str: + """Serialize the object to JSON. + + This method uses `orjson` to serialize the object to JSON as quickly as possible. + """ + return orjson.dumps(self.to_dict()).decode("utf-8") + + def __str__(self) -> str: + """Casting to `str` is the same as casting to JSON. + + These are equivalent: + >>> msg = AirbyteMessage(...) + >>> str(msg) + >>> msg.to_json() + """ + return self.to_json() + + @classmethod + def from_dict(cls: type[T], data: dict[str, Any], /) -> T: + return cls._serializer.load(data) + + @classmethod + def from_json(cls: type[T], str_value: str, /) -> T: + """Load the object from JSON. + + This method first tries to deserialize the JSON string using `orjson.loads()`, + falling back to `json.loads()` if it fails. This is because `orjson` does not support + all JSON features, such as `NaN` and `Infinity`, which are supported by the standard + `json` module. The `orjson` library is used for its speed and efficiency, while the + standard `json` library is used as a fallback for compatibility with more complex JSON + structures. + + Raises: + orjson.JSONDecodeError: If the JSON string cannot be deserialized by either + `orjson` or `json`. 
+ """ + try: + dict_value = orjson.loads(str_value) + except orjson.JSONDecodeError as orjson_error: + try: + dict_value = json.loads(str_value) + except json.JSONDecodeError as json_error: + # Callers will expect `orjson.JSONDecodeError`, so we raise the original + # `orjson` error when both options fail. + # We also attach the second error, in case it is useful for debugging. + raise orjson_error from json_error + + return cls.from_dict(dict_value) + + # The following dataclasses have been redeclared to include the new version of AirbyteStateBlob @dataclass -class AirbyteStreamState: +class AirbyteStreamState(AirbyteStreamState, SerDeMixin): stream_descriptor: StreamDescriptor # type: ignore [name-defined] stream_state: Optional[AirbyteStateBlob] = None +AirbyteStreamState._serializer = Serializer( + AirbyteStreamState, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + + @dataclass -class AirbyteGlobalState: +class AirbyteGlobalState(SerDeMixin): stream_states: List[AirbyteStreamState] shared_state: Optional[AirbyteStateBlob] = None +AirbyteGlobalState._serializer = Serializer( + AirbyteGlobalState, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + @dataclass -class AirbyteStateMessage: - type: Optional[AirbyteStateType] = None # type: ignore [name-defined] +class AirbyteStateMessage(SerDeMixin): + type: AirbyteStateType | None = None # type: ignore [name-defined] stream: Optional[AirbyteStreamState] = None global_: Annotated[AirbyteGlobalState | None, Alias("global")] = ( None # "global" is a reserved keyword in python ⇒ Alias is used for (de-)serialization @@ -75,8 +223,15 @@ class AirbyteStateMessage: destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined] +AirbyteStateMessage._serializer = Serializer( + AirbyteStateMessage, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + + @dataclass -class AirbyteMessage: +class AirbyteMessage(SerDeMixin): type: Type # type: ignore 
[name-defined] log: Optional[AirbyteLogMessage] = None # type: ignore [name-defined] spec: Optional[ConnectorSpecification] = None # type: ignore [name-defined] @@ -86,3 +241,50 @@ class AirbyteMessage: state: Optional[AirbyteStateMessage] = None trace: Optional[AirbyteTraceMessage] = None # type: ignore [name-defined] control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined] + + +AirbyteMessage._serializer = Serializer( + AirbyteMessage, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + + +class ConfiguredAirbyteCatalog(ConfiguredAirbyteCatalog, SerDeMixin): + pass + + +ConfiguredAirbyteCatalog._serializer = Serializer( + ConfiguredAirbyteCatalog, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + + +class ConfiguredAirbyteStream(ConfiguredAirbyteStream, SerDeMixin): + pass + + +ConfiguredAirbyteStream._serializer = Serializer( + ConfiguredAirbyteStream, + omit_none=True, + custom_type_resolver=_custom_state_resolver, +) + + +class ConnectorSpecification(ConnectorSpecification, SerDeMixin): + pass + + +ConnectorSpecification._serializer = Serializer( + ConnectorSpecification, + omit_none=True, +) + +# Deprecated Serializer Classes. 
Declared here for legacy compatibility: +AirbyteStreamStateSerializer = AirbyteStreamState._serializer # type: ignore +AirbyteStateMessageSerializer = AirbyteStateMessage._serializer # type: ignore +AirbyteMessageSerializer = AirbyteMessage._serializer # type: ignore +ConfiguredAirbyteCatalogSerializer = ConfiguredAirbyteCatalog._serializer # type: ignore +ConfiguredAirbyteStreamSerializer = ConfiguredAirbyteStream._serializer # type: ignore +ConnectorSpecificationSerializer = ConnectorSpecification._serializer # type: ignore diff --git a/airbyte_cdk/models/airbyte_protocol_serializers.py b/airbyte_cdk/models/airbyte_protocol_serializers.py index 129556acc..70ee424e3 100644 --- a/airbyte_cdk/models/airbyte_protocol_serializers.py +++ b/airbyte_cdk/models/airbyte_protocol_serializers.py @@ -1,44 +1,14 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -from typing import Any, Dict +"""This module is deprecated and exists only for legacy compatibility. -from serpyco_rs import CustomType, Serializer +Instead of importing from this module, callers should import from +`airbyte_cdk.models.airbyte_protocol` directly. -from .airbyte_protocol import ( # type: ignore[attr-defined] # all classes are imported to airbyte_protocol via * - AirbyteMessage, - AirbyteStateBlob, - AirbyteStateMessage, - AirbyteStreamState, - ConfiguredAirbyteCatalog, - ConfiguredAirbyteStream, - ConnectorSpecification, -) +The dedicated SerDes classes are _also_ deprecated. Instead, use these methods: +- `from_dict()` +- `from_json()` +- `to_dict()` +- `to_json()` +""" - -class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]): - def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]: - # cant use orjson.dumps() directly because private attributes are excluded, e.g. 
"__ab_full_refresh_sync_complete" - return {k: v for k, v in value.__dict__.items()} - - def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob: - return AirbyteStateBlob(value) - - def get_json_schema(self) -> Dict[str, Any]: - return {"type": "object"} - - -def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any]] | None: - return AirbyteStateBlobType() if t is AirbyteStateBlob else None - - -AirbyteStreamStateSerializer = Serializer( - AirbyteStreamState, omit_none=True, custom_type_resolver=custom_type_resolver -) -AirbyteStateMessageSerializer = Serializer( - AirbyteStateMessage, omit_none=True, custom_type_resolver=custom_type_resolver -) -AirbyteMessageSerializer = Serializer( - AirbyteMessage, omit_none=True, custom_type_resolver=custom_type_resolver -) -ConfiguredAirbyteCatalogSerializer = Serializer(ConfiguredAirbyteCatalog, omit_none=True) -ConfiguredAirbyteStreamSerializer = Serializer(ConfiguredAirbyteStream, omit_none=True) -ConnectorSpecificationSerializer = Serializer(ConnectorSpecification, omit_none=True) +from airbyte_cdk.models.airbyte_protocol import * # type: ignore[attr-defined] diff --git a/airbyte_cdk/sources/__init__.py b/airbyte_cdk/sources/__init__.py index a6560a503..12e510bbd 100644 --- a/airbyte_cdk/sources/__init__.py +++ b/airbyte_cdk/sources/__init__.py @@ -4,9 +4,7 @@ import dpath.options -from .abstract_source import AbstractSource from .config import BaseConfig -from .source import Source # As part of the CDK sources, we do not control what the APIs return and it is possible that a key is empty. 
# Reasons why we are doing this at the airbyte_cdk level: @@ -20,7 +18,5 @@ dpath.options.ALLOW_EMPTY_STRING_KEYS = True __all__ = [ - "AbstractSource", "BaseConfig", - "Source", ] diff --git a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py index c150dc956..be29cf8d4 100644 --- a/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +++ b/airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Iterator, List, Mapping, MutableMapping, Optional, Tuple from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.streams import Stream diff --git a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py index 12e1740b6..0b62da6b1 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/abstract_file_based_availability_strategy.py @@ -8,7 +8,6 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Optional, Tuple -from airbyte_cdk.sources import Source from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( AbstractAvailabilityStrategy, @@ -20,6 +19,7 @@ if TYPE_CHECKING: from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream + from airbyte_cdk.sources.source import Source class 
AbstractFileBasedAvailabilityStrategy(AvailabilityStrategy): diff --git a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py index c9d416a72..0a724a2de 100644 --- a/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py +++ b/airbyte_cdk/sources/file_based/availability_strategy/default_file_based_availability_strategy.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Optional, Tuple from airbyte_cdk import AirbyteTracedException -from airbyte_cdk.sources import Source from airbyte_cdk.sources.file_based.availability_strategy import ( AbstractFileBasedAvailabilityStrategy, ) @@ -18,12 +17,15 @@ CustomFileBasedException, FileBasedSourceError, ) -from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader -from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.schema_helpers import conforms_to_schema if TYPE_CHECKING: + from airbyte_cdk.sources.file_based.file_based_stream_reader import ( + AbstractFileBasedStreamReader, + ) + from airbyte_cdk.sources.file_based.remote_file import RemoteFile from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream + from airbyte_cdk.sources.source import Source class DefaultFileBasedAvailabilityStrategy(AbstractFileBasedAvailabilityStrategy): diff --git a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py index c36e5179d..ecaa63da3 100644 --- a/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +++ b/airbyte_cdk/sources/file_based/stream/concurrent/adapters.py @@ -17,7 +17,7 @@ SyncMode, Type, ) -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.connector_state_manager import 
ConnectorStateManager
 from airbyte_cdk.sources.file_based.availability_strategy import (
     AbstractFileBasedAvailabilityStrategy,
diff --git a/airbyte_cdk/sources/message/repository.py b/airbyte_cdk/sources/message/repository.py
index 2fc156e8c..cdcf0eeee 100644
--- a/airbyte_cdk/sources/message/repository.py
+++ b/airbyte_cdk/sources/message/repository.py
@@ -4,6 +4,7 @@
 
 import json
 import logging
+import sys
 from abc import ABC, abstractmethod
 from collections import deque
 from typing import Callable, Deque, Iterable, List, Optional
@@ -60,6 +61,43 @@ def consume_queue(self) -> Iterable[AirbyteMessage]:
         raise NotImplementedError()
 
 
+class PassthroughMessageRepository(MessageRepository):
+    """A message repository which simply passes output on to STDOUT."""
+
+    def __init__(
+        self,
+        log_level: Level = Level.WARN,
+    ) -> None:
+        """Initialize the message repository.
+
+        Log level is configurable.
+        """
+        self._log_level: Level = log_level
+
+    def emit_message(self, message: AirbyteMessage) -> None:
+        """Passthrough message to STDOUT."""
+        sys.stdout.write(message.to_json() + "\n")  # protocol messages are newline-delimited JSON
+
+    def log_message(
+        self,
+        level: Level,
+        message_provider: Callable[[], LogMessage],
+    ) -> None:
+        if _is_severe_enough(self._log_level, level):
+            self.emit_message(
+                AirbyteMessage(
+                    type=Type.LOG,
+                    log=AirbyteLogMessage(
+                        level=level, message=filter_secrets(json.dumps(message_provider()))
+                    ),
+                )
+            )
+
+    def consume_queue(self) -> Iterable[AirbyteMessage]:
+        """No-op, since nothing is queued in the passthrough message repository."""
+        return []
+
+
 class NoopMessageRepository(MessageRepository):
     def emit_message(self, message: AirbyteMessage) -> None:
         pass
@@ -72,6 +110,7 @@ def consume_queue(self) -> Iterable[AirbyteMessage]:
 
 
 class InMemoryMessageRepository(MessageRepository):
+
     def __init__(self, log_level: Level = Level.INFO) -> None:
         self._message_queue: Deque[AirbyteMessage] = deque()
         self._log_level = log_level
diff --git a/airbyte_cdk/sources/source.py
b/airbyte_cdk/sources/source.py
index 2958d82ca..aea85290f 100644
--- a/airbyte_cdk/sources/source.py
+++ b/airbyte_cdk/sources/source.py
@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Generic, Iterable, List, Mapping, Optional, TypeVar
 
-from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
+from airbyte_cdk.connector import BaseConnector, TConfig
 from airbyte_cdk.models import (
     AirbyteCatalog,
     AirbyteMessage,
@@ -53,7 +53,6 @@ def discover(self, logger: logging.Logger, config: TConfig) -> AirbyteCatalog:
 
 
 class Source(
-    DefaultConnectorMixin,
     BaseSource[Mapping[str, Any], List[AirbyteStateMessage], ConfiguredAirbyteCatalog],
     ABC,
 ):
diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py
index 7da594155..a2539b42c 100644
--- a/airbyte_cdk/sources/streams/concurrent/adapters.py
+++ b/airbyte_cdk/sources/streams/concurrent/adapters.py
@@ -19,7 +19,8 @@
     SyncMode,
     Type,
 )
-from airbyte_cdk.sources import AbstractSource, Source
+from airbyte_cdk.sources.abstract_source import AbstractSource
+from airbyte_cdk.sources.source import Source
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.message import MessageRepository
 from airbyte_cdk.sources.source import ExperimentalClassWarning
diff --git a/airbyte_cdk/test/entrypoint_wrapper.py b/airbyte_cdk/test/entrypoint_wrapper.py
index 79c328203..5eff950ef 100644
--- a/airbyte_cdk/test/entrypoint_wrapper.py
+++ b/airbyte_cdk/test/entrypoint_wrapper.py
@@ -19,15 +19,16 @@
 import re
 import tempfile
 import traceback
+from collections.abc import Callable, Iterable
 from io import StringIO
 from pathlib import Path
 from typing import Any, List, Mapping, Optional, Union
 
 import orjson
-from pydantic import ValidationError as V2ValidationError
 from serpyco_rs import SchemaValidationError
+from typing_extensions import deprecated
 
-from airbyte_cdk.entrypoint
import AirbyteEntrypoint +from airbyte_cdk.connector_builder.models import LogMessage from airbyte_cdk.exception_handler import assemble_uncaught_exception from airbyte_cdk.logger import AirbyteLogFormatter from airbyte_cdk.models import ( @@ -43,27 +46,66 @@ TraceType, Type, ) -from airbyte_cdk.sources import Source +from airbyte_cdk.models.airbyte_protocol import AirbyteMessage +from airbyte_cdk.sources.message.repository import ( + InMemoryMessageRepository, + MessageRepository, + _is_severe_enough, +) +from airbyte_cdk.sources.source import Source +from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets +from airbyte_cdk.utils.cli_arg_parse import ConnectorCLIArgs, parse_cli_args -class EntrypointOutput: - def __init__(self, messages: List[str], uncaught_exception: Optional[BaseException] = None): - try: - self._messages = [self._parse_message(message) for message in messages] - except V2ValidationError as exception: - raise ValueError("All messages are expected to be AirbyteMessage") from exception +class TestOutputMessageRepository(MessageRepository): + """An implementation of MessageRepository used for testing. - if uncaught_exception: - self._messages.append( - assemble_uncaught_exception( - type(uncaught_exception), uncaught_exception - ).as_airbyte_message() + It captures both the messages emitted by the source and the logs printed to stdout. + + This class replaces `EntrypointOutput`. + + Warning: OOM errors may occur if the source generates a large number of messages. + TODO: Optimize this to switch to a disk-side buffer if available memory is at risk of being + overrun. 
+ """ + + def __init__(self, log_level: Level = Level.INFO) -> None: + self._log_level = log_level + self._messages: list[AirbyteMessage] = [] + self._ignored_logs: list[LogMessage] = [] + self._consumed_to_marker = 0 + + def emit_message(self, message: AirbyteMessage) -> None: + self._messages.append(message) + + def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None: + if _is_severe_enough(self._log_level, level): + self.emit_message( + AirbyteMessage( + type=Type.LOG, + log=AirbyteLogMessage( + level=level, message=filter_secrets(json.dumps(message_provider())) + ), + ) ) + else: + self._ignored_logs.append(message_provider()) + + def consume_queue(self) -> Iterable[AirbyteMessage]: + """Consume the message queue and return all messages. + + This method primarily exists to support the `MessageRepository` interface. + Note: Callers can more easily consume the queue by reading from `messages` directly. + + To avoid race conditions, we first get the high-water mark and then return to that point. 
+ """ + self._consumed_to_marker = len(self._messages) + return self._messages[: self._consumed_to_marker] @staticmethod def _parse_message(message: str) -> AirbyteMessage: try: - return AirbyteMessageSerializer.load(orjson.loads(message)) + return AirbyteMessage.from_json(message) except (orjson.JSONDecodeError, SchemaValidationError): # The platform assumes that logs that are not of AirbyteMessage format are log messages return AirbyteMessage( @@ -130,7 +172,8 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]: ) return list(status_messages) - def _get_message_by_types(self, message_types: List[Type]) -> List[AirbyteMessage]: + def _get_message_by_types(self, message_types: list[Type]) -> list[AirbyteMessage]: + """Return all messages of the given types.""" return [message for message in self._messages if message.type in message_types] def _get_trace_message_by_trace_type(self, trace_type: TraceType) -> List[AirbyteMessage]: @@ -156,9 +199,38 @@ def is_not_in_logs(self, pattern: str) -> bool: return not self.is_in_logs(pattern) +@deprecated("Please use `TestOutputMessageRepository` instead.") +class EntrypointOutput(TestOutputMessageRepository): + """A class that captures the output of the entrypoint. + + It captures both the messages emitted by the source and the logs printed to stdout. 
+    """
+
+    def __init__(
+        self,
+        messages: list[str] | None = None,
+        uncaught_exception: BaseException | None = None,
+    ) -> None:
+        super().__init__()
+        for msg in messages or []:
+            self.emit_message(self._parse_message(msg))
+
+        self._uncaught_exception = uncaught_exception
+        if uncaught_exception:
+            self.emit_message(
+                assemble_uncaught_exception(
+                    type(uncaught_exception), uncaught_exception
+                ).as_airbyte_message()
+            )
+
+    @property
+    def uncaught_exception(self) -> BaseException | None:
+        return self._uncaught_exception
+
+
 def _run_command(
-    source: Source, args: List[str], expecting_exception: bool = False
-) -> EntrypointOutput:
+    source: Source, args: list[str], expecting_exception: bool = False
+) -> TestOutputMessageRepository:
     log_capture_buffer = StringIO()
     stream_handler = logging.StreamHandler(log_capture_buffer)
     stream_handler.setLevel(logging.INFO)
@@ -166,32 +238,30 @@ def _run_command(
     parent_logger = logging.getLogger("")
     parent_logger.addHandler(stream_handler)
 
-    parsed_args = AirbyteEntrypoint.parse_args(args)
-
-    source_entrypoint = AirbyteEntrypoint(source)
-    messages = []
-    uncaught_exception = None
+    message_repository = TestOutputMessageRepository()
     try:
-        for message in source_entrypoint.run(parsed_args):
-            messages.append(message)
+        source.launch_with_cli_args(
+            args,
+            logger=parent_logger,
+            message_repository=message_repository,
+        )
     except Exception as exception:
         if not expecting_exception:
             print("Printing unexpected error from entrypoint_wrapper")
             print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
         uncaught_exception = exception
 
-    captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]
-    parent_logger.removeHandler(stream_handler)
-    return EntrypointOutput(messages + captured_logs, uncaught_exception)
+    parent_logger.removeHandler(stream_handler)  # always detach, or each call leaks a handler on the root logger
+    return message_repository
 
 
 def discover(
     source: Source,
     config: Mapping[str, Any],
     expecting_exception: bool = False,
-) -> EntrypointOutput:
+) -> TestOutputMessageRepository:
     """
config must be json serializable :param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please @@ -213,7 +282,7 @@ def read( catalog: ConfiguredAirbyteCatalog, state: Optional[List[AirbyteStateMessage]] = None, expecting_exception: bool = False, -) -> EntrypointOutput: +) -> TestOutputMessageRepository: """ config and state must be json serializable diff --git a/airbyte_cdk/test/standard_tests/_job_runner.py b/airbyte_cdk/test/standard_tests/_job_runner.py index ad8316d78..aa2fbd58b 100644 --- a/airbyte_cdk/test/standard_tests/_job_runner.py +++ b/airbyte_cdk/test/standard_tests/_job_runner.py @@ -22,7 +22,7 @@ def _errors_to_str( - entrypoint_output: entrypoint_wrapper.EntrypointOutput, + entrypoint_output: entrypoint_wrapper.TestOutputMessageRepository, ) -> str: """Convert errors from entrypoint output to a string.""" if not entrypoint_output.errors: @@ -60,7 +60,7 @@ def run_test_job( *, test_scenario: ConnectorTestScenario | None = None, catalog: ConfiguredAirbyteCatalog | dict[str, Any] | None = None, -) -> entrypoint_wrapper.EntrypointOutput: +) -> entrypoint_wrapper.TestOutputMessageRepository: """Run a test scenario from provided CLI args and return the result.""" # Use default (empty) scenario if not provided: test_scenario = test_scenario or ConnectorTestScenario() @@ -115,7 +115,7 @@ def run_test_job( # This is a bit of a hack because the source needs the catalog early. # Because it *also* can fail, we have to redundantly wrap it in a try/except block. 
- result: entrypoint_wrapper.EntrypointOutput = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API + result: entrypoint_wrapper.TestOutputMessageRepository = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API source=connector_obj, # type: ignore [arg-type] args=args, expecting_exception=test_scenario.expect_exception, diff --git a/airbyte_cdk/test/standard_tests/connector_base.py b/airbyte_cdk/test/standard_tests/connector_base.py index 394028247..4c815b857 100644 --- a/airbyte_cdk/test/standard_tests/connector_base.py +++ b/airbyte_cdk/test/standard_tests/connector_base.py @@ -112,7 +112,7 @@ def test_check( scenario: ConnectorTestScenario, ) -> None: """Run `connection` acceptance tests.""" - result: entrypoint_wrapper.EntrypointOutput = run_test_job( + result: entrypoint_wrapper.TestOutputMessageRepository = run_test_job( self.create_connector(scenario), "check", test_scenario=scenario, diff --git a/airbyte_cdk/test/standard_tests/source_base.py b/airbyte_cdk/test/standard_tests/source_base.py index a256fa04c..a7caa1948 100644 --- a/airbyte_cdk/test/standard_tests/source_base.py +++ b/airbyte_cdk/test/standard_tests/source_base.py @@ -39,7 +39,7 @@ def test_check( This test is designed to validate the connector's ability to establish a connection and return its status with the expected message type. """ - result: entrypoint_wrapper.EntrypointOutput = run_test_job( + result: entrypoint_wrapper.TestOutputMessageRepository = run_test_job( self.create_connector(scenario), "check", test_scenario=scenario, @@ -154,7 +154,7 @@ def test_fail_read_with_bad_catalog( ) # Set expected status to "failed" to ensure the test fails if the connector. 
scenario.status = "failed" - result: entrypoint_wrapper.EntrypointOutput = run_test_job( + result: entrypoint_wrapper.TestOutputMessageRepository = run_test_job( self.create_connector(scenario), "read", test_scenario=scenario, diff --git a/airbyte_cdk/test/utils/reading.py b/airbyte_cdk/test/utils/reading.py index 2d89cb870..c13af823c 100644 --- a/airbyte_cdk/test/utils/reading.py +++ b/airbyte_cdk/test/utils/reading.py @@ -5,7 +5,7 @@ from airbyte_cdk import AbstractSource from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog, SyncMode from airbyte_cdk.test.catalog_builder import CatalogBuilder -from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read +from airbyte_cdk.test.entrypoint_wrapper import TestOutputMessageRepository, read def catalog(stream_name: str, sync_mode: SyncMode) -> ConfiguredAirbyteCatalog: @@ -20,7 +20,7 @@ def read_records( sync_mode: SyncMode, state: Optional[List[AirbyteStateMessage]] = None, expecting_exception: bool = False, -) -> EntrypointOutput: +) -> TestOutputMessageRepository: """Read records from a stream.""" _catalog = catalog(stream_name, sync_mode) return read(source, config, _catalog, state, expecting_exception) diff --git a/airbyte_cdk/utils/cli_arg_parse.py b/airbyte_cdk/utils/cli_arg_parse.py new file mode 100644 index 000000000..7b3d85767 --- /dev/null +++ b/airbyte_cdk/utils/cli_arg_parse.py @@ -0,0 +1,206 @@ +# Copyright (c) 2025 Airbyte, Inc., all rights reserved. +"""CLI Argument Parsing Utilities.""" + +import argparse +import json +from collections.abc import MutableMapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +def _read_json_file( + file_path: Path | str, +) -> dict[str, Any]: + """Read a JSON file and return its contents as a dictionary. + + Raises ValueError if the file cannot be read or is not valid JSON. 
+ """ + file_text = Path(file_path).read_text() + + try: + return json.loads(file_text) + except json.JSONDecodeError as error: + raise ValueError( + f"Could not read json file {file_path}: {error}. Please ensure that it is a valid JSON." + ) + + +def parse_cli_args( + args: list[str], + *, + with_read: bool = True, + with_discover: bool = True, + with_write: bool = False, +) -> argparse.Namespace: + """Return the parsed CLI arguments for the connector. + + The caller can validate the arguments and use them as needed. This function allows all possible + arguments to be passed in, but the caller should only use the ones that are relevant to the + command being executed. By default, we expect a typical "source" configuration. + + Optionally, caller may specify command availability by overriding the `with` flags. + """ + # set up parent parsers + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser.add_argument( + "--debug", action="store_true", help="enables detailed debug logs related to the sync" + ) + main_parser = argparse.ArgumentParser() + subparsers = main_parser.add_subparsers( + title="commands", + dest="command", + required=True, + ) + + # spec + subparsers.add_parser( + "spec", help="outputs the json configuration specification", parents=[parent_parser] + ) + + # check + check_parser = subparsers.add_parser( + "check", help="checks the config can be used to connect", parents=[parent_parser] + ) + required_check_parser = check_parser.add_argument_group("required named arguments") + required_check_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) + check_parser.add_argument( + "--manifest-path", + type=str, + required=False, + help="path to the YAML manifest file to inject into the config", + ) + check_parser.add_argument( + "--components-path", + type=str, + required=False, + help="path to the custom components file, if it exists", + ) + + if with_discover: + discover_parser = 
subparsers.add_parser( + "discover", + help="outputs a catalog describing the source's schema", + parents=[parent_parser], + ) + required_discover_parser = discover_parser.add_argument_group("required named arguments") + required_discover_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) + discover_parser.add_argument( + "--manifest-path", + type=str, + required=False, + help="path to the YAML manifest file to inject into the config", + ) + discover_parser.add_argument( + "--components-path", + type=str, + required=False, + help="path to the custom components file, if it exists", + ) + + if with_read: + read_parser = subparsers.add_parser( + "read", help="reads the source and outputs messages to STDOUT", parents=[parent_parser] + ) + + read_parser.add_argument( + "--state", type=str, required=False, help="path to the json-encoded state file" + ) + required_read_parser = read_parser.add_argument_group("required named arguments") + required_read_parser.add_argument( + "--config", type=str, required=True, help="path to the json configuration file" + ) + required_read_parser.add_argument( + "--catalog", + type=str, + required=True, + help="path to the catalog used to determine which data to read", + ) + read_parser.add_argument( + "--manifest-path", + type=str, + required=False, + help="path to the YAML manifest file to inject into the config", + ) + read_parser.add_argument( + "--components-path", + type=str, + required=False, + help="path to the custom components file, if it exists", + ) + + if with_write: + # write + write_parser = subparsers.add_parser( + "write", help="Writes data to the destination", parents=[parent_parser] + ) + write_required = write_parser.add_argument_group("required named arguments") + write_required.add_argument( + "--config", type=str, required=True, help="path to the JSON configuration file" + ) + write_required.add_argument( + "--catalog", type=str, required=True, help="path to the 
configured catalog JSON file" + ) + + return main_parser.parse_args(args) + + +@dataclass(kw_only=True) +class ConnectorCLIArgs: + """Strongly typed dataclass to hold CLI arguments for the connector. + + This class can be used as a type-safe alternative to argparse.Namespace. + """ + + command: str + debug: bool | None = None + config: str | None = None + state: str | None = None + catalog: str | None = None + manifest_path: str | None = None + components_path: str | None = None + + def get_config_dict( + self, + *, + allow_missing: bool = False, + ) -> MutableMapping[str, Any]: + """Read the config file and return its contents as a dictionary. + + If allow_missing is True, return an empty dictionary when the config file is not provided. + """ + if self.config is None: + if not allow_missing: + raise ValueError("Config file path is required.") + + return {} + + config = _read_json_file(self.config) + if isinstance(config, MutableMapping): + return config + else: + raise ValueError( + f"The content of {self.config} is not an object and therefore is not a valid config. Please ensure the file represent a config." 
+ ) + + @classmethod + def from_namespace( + cls, + parsed_args: argparse.Namespace, + ) -> "ConnectorCLIArgs": + """Create a ConnectorCLIArgs instance from an argparse Namespace.""" + return cls( + command=parsed_args.command, + debug=parsed_args.debug if "debug" in parsed_args else None, + config=parsed_args.config if "config" in parsed_args else None, + state=parsed_args.state if "state" in parsed_args else None, + catalog=parsed_args.catalog if "catalog" in parsed_args else None, + manifest_path=parsed_args.manifest_path if "manifest_path" in parsed_args else None, + components_path=parsed_args.components_path + if "components_path" in parsed_args + else None, + ) diff --git a/cdk-migrations.md b/cdk-migrations.md index d07c184d2..130151b2e 100644 --- a/cdk-migrations.md +++ b/cdk-migrations.md @@ -1,5 +1,17 @@ # CDK Migration Guide +## Upgrading to 6.XX.YY + +This version deprecates (but does not remove) the AirbyteEntrypoint class and related methods. +Deprecation warnings will be emitted beginning in this version if `launch()` is called, or any +other now-deprecated methods or classes. + +Beginning in this version, all connectors have a `launch_with_cli_args()` class method, which +can be called directly from the class itself. + +Most connector classes will not need to modify this class method, although it can be overriden as +needed, if any custom behavior is required. + ## Upgrading to 6.34.0 [Version 6.34.0](https://github.com/airbytehq/airbyte-python-cdk/releases/tag/v6.34.0) of the CDK removes support for `stream_state` in the Jinja interpolation context. This change is breaking for any low-code connectors that use `stream_state` in the interpolation context. 
diff --git a/debug_manifest/debug_manifest.py b/debug_manifest/debug_manifest.py index c520d0b0c..ceb74f046 100644 --- a/debug_manifest/debug_manifest.py +++ b/debug_manifest/debug_manifest.py @@ -19,7 +19,7 @@ def debug_manifest(source: YamlDeclarativeSource, args: list[str]) -> None: """ Run the debug manifest with the given source and arguments. """ - launch(source, args) + source.launch_with_cli_args(args) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 938145e95..57af48f55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,7 +199,7 @@ log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" filterwarnings = [ - "ignore::airbyte_cdk.sources.source.ExperimentalClassWarning" + # "ignore::airbyte_cdk.sources.source.ExperimentalClassWarning" ] [tool.airbyte_ci] diff --git a/ruff-next.toml b/ruff-next.toml new file mode 100644 index 000000000..6c5ca332d --- /dev/null +++ b/ruff-next.toml @@ -0,0 +1,29 @@ +# Ruff configuration (extended for development guidance) +# +# To use these in your IDE, add this snippet to `.vscode/settings.json` +# in this repo: +# +# ```jsonc +# { +# // ... +# "ruff.configuration": "ruff-next.toml", +# } +target-version = "py310" +line-length = 100 + +[lint] +# These will block CI if they fail. +select = [ + "I", +] + +# These can't pass yet, but we can enabled them in the IDE for guidance. +extend-select = [ + "E", + "W", + "C", + "F", + "N", + "D", + "TC", +] diff --git a/ruff.toml b/ruff.toml index 5ed2f45e2..e43d60a09 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,22 @@ -# Ruff configuration moved from pyproject.toml +# Ruff configuration (default for CI checks) target-version = "py310" line-length = 100 -[lint] -select = ["I"] +[lint] # NOTE: When updating rules here, please also update the `ruff-next.toml` file. 
+
+# The following will block CI if they fail:
+select = [
+    "I",
+]
+
+# These can't pass yet, so we disable them for now.
+# See ruff-next.toml for a config that includes these rules.
+# extend-select = [
+#     "E",
+#     "W",
+#     "C",
+#     "F",
+#     "N",
+#     "D",
+#     "TC",
+# ]
diff --git a/unit_tests/destinations/test_destination.py b/unit_tests/destinations/test_destination.py
index 1f8f6573f..4c4a6e991 100644
--- a/unit_tests/destinations/test_destination.py
+++ b/unit_tests/destinations/test_destination.py
@@ -32,6 +32,7 @@
     SyncMode,
     Type,
 )
+from airbyte_cdk.utils.cli_arg_parse import ConnectorCLIArgs, parse_cli_args
 
 
 @pytest.fixture(name="destination")
@@ -57,7 +58,7 @@ class TestArgParsing:
     def test_successful_parse(
         self, arg_list: List[str], expected_output: Mapping[str, Any], destination: Destination
     ):
-        parsed_args = vars(destination.parse_args(arg_list))
+        parsed_args = vars(parse_cli_args(arg_list))  # vars() so the Namespace compares equal to the expected mapping
         assert parsed_args == expected_output, (
             f"Expected parsing {arg_list} to return parsed args {expected_output} but instead found {parsed_args}"
         )
@@ -80,7 +81,7 @@ def test_failed_parse(self, arg_list: List[str], destination: Destination):
         # We use BaseException because it encompasses SystemExit (raised by failed parsing) and other exceptions (raised by additional semantic
         # checks)
         with pytest.raises(BaseException):
-            destination.parse_args(arg_list)
+            parse_cli_args(arg_list)
 
 
 def _state(state: Dict[str, Any]) -> AirbyteStateMessage:
@@ -156,7 +157,6 @@ def test_run_initializes_exception_handler(
     destination: Destination,
 ) -> None:
     mocker.patch.object(destination_module, "init_uncaught_exception_handler")
-    mocker.patch.object(destination, "parse_args")
     mocker.patch.object(destination, "run_cmd")
     destination.run(["dummy"])
     destination_module.init_uncaught_exception_handler.assert_called_once_with(
diff --git a/unit_tests/sources/file_based/scenarios/scenario_builder.py b/unit_tests/sources/file_based/scenarios/scenario_builder.py
index
93e8f952a..f63e29d02 100644 --- a/unit_tests/sources/file_based/scenarios/scenario_builder.py +++ b/unit_tests/sources/file_based/scenarios/scenario_builder.py @@ -13,7 +13,7 @@ ConfiguredAirbyteCatalogSerializer, SyncMode, ) -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.source import TState diff --git a/unit_tests/sources/file_based/test_scenarios.py b/unit_tests/sources/file_based/test_scenarios.py index d70b7f4ef..04e00eb58 100644 --- a/unit_tests/sources/file_based/test_scenarios.py +++ b/unit_tests/sources/file_based/test_scenarios.py @@ -19,7 +19,7 @@ ConfiguredAirbyteCatalogSerializer, SyncMode, ) -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.file_based.stream.concurrent.cursor import ( AbstractConcurrentFileBasedCursor, ) @@ -227,10 +227,7 @@ def verify_check( def spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> Mapping[str, Any]: - launch( - scenario.source, - ["spec"], - ) + scenario.source.launch_with_cli_args(["spec"]) captured = capsys.readouterr() return json.loads(captured.out.splitlines()[0])["spec"] # type: ignore @@ -238,9 +235,12 @@ def spec(capsys: CaptureFixture[str], scenario: TestScenario[AbstractSource]) -> def check( capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] ) -> Dict[str, Any]: - launch( - scenario.source, - ["check", "--config", make_file(tmp_path / "config.json", scenario.config)], + scenario.source.launch_with_cli_args( + [ + "check", + "--config", + make_file(tmp_path / "config.json", scenario.config), + ] ) captured = capsys.readouterr() return _find_connection_status(captured.out.splitlines()) @@ -257,8 +257,7 @@ def _find_connection_status(output: List[str]) -> Mapping[str, Any]: def discover( capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource] 
) -> Dict[str, Any]: - launch( - scenario.source, + scenario.source.launch_with_cli_args( ["discover", "--config", make_file(tmp_path / "config.json", scenario.config)], ) output = [json.loads(line) for line in capsys.readouterr().out.splitlines()] diff --git a/unit_tests/sources/fixtures/source_test_fixture.py b/unit_tests/sources/fixtures/source_test_fixture.py index 3c2183b68..90158a25a 100644 --- a/unit_tests/sources/fixtures/source_test_fixture.py +++ b/unit_tests/sources/fixtures/source_test_fixture.py @@ -18,7 +18,7 @@ DestinationSyncMode, SyncMode, ) -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.http import HttpStream from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator diff --git a/unit_tests/sources/mock_server_tests/mock_source_fixture.py b/unit_tests/sources/mock_server_tests/mock_source_fixture.py index 5ca7ae7cd..f762df410 100644 --- a/unit_tests/sources/mock_server_tests/mock_source_fixture.py +++ b/unit_tests/sources/mock_server_tests/mock_source_fixture.py @@ -11,7 +11,8 @@ from requests import HTTPError from airbyte_cdk.models import ConnectorSpecification, SyncMode -from airbyte_cdk.sources import AbstractSource, Source +from airbyte_cdk.sources.abstract_source import AbstractSource +from airbyte_cdk.sources.source import Source from airbyte_cdk.sources.streams import CheckpointMixin, IncrementalMixin, Stream from airbyte_cdk.sources.streams.core import StreamData from airbyte_cdk.sources.streams.http import HttpStream diff --git a/unit_tests/sources/test_abstract_source.py b/unit_tests/sources/test_abstract_source.py index 4ca7f7fb6..eb819513a 100644 --- a/unit_tests/sources/test_abstract_source.py +++ b/unit_tests/sources/test_abstract_source.py @@ -49,7 +49,7 @@ Type, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources import AbstractSource 
+from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams import IncrementalMixin, Stream from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message diff --git a/unit_tests/sources/test_integration_source.py b/unit_tests/sources/test_integration_source.py index 39573fd35..6940d87bb 100644 --- a/unit_tests/sources/test_integration_source.py +++ b/unit_tests/sources/test_integration_source.py @@ -77,11 +77,11 @@ def test_external_request_source( args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: with pytest.raises(AirbyteTracedException): - launch(source, args) + source.launch_with_cli_args(args) messages = [json.loads(line) for line in capsys.readouterr().out.splitlines()] assert contains_error_trace_message(messages, expected_error) else: - launch(source, args) + source.launch_with_cli_args(args) @pytest.mark.parametrize( @@ -138,11 +138,11 @@ def test_external_oauth_request_source( args = ["read", "--config", "config.json", "--catalog", "configured_catalog.json"] if expected_error: with pytest.raises(AirbyteTracedException): - launch(source, args) + source.launch_with_cli_args(args) messages = [json.loads(line) for line in capsys.readouterr().out.splitlines()] assert contains_error_trace_message(messages, expected_error) else: - launch(source, args) + source.launch_with_cli_args(args) def contains_error_trace_message(messages: List[Mapping[str, Any]], expected_error: str) -> bool: diff --git a/unit_tests/sources/test_source.py b/unit_tests/sources/test_source.py index 9554d2242..e05bfa7d3 100644 --- a/unit_tests/sources/test_source.py +++ b/unit_tests/sources/test_source.py @@ -25,7 +25,8 @@ SyncMode, Type, ) -from airbyte_cdk.sources import AbstractSource, Source +from airbyte_cdk.sources.abstract_source import AbstractSource +from airbyte_cdk.sources.source import Source from 
airbyte_cdk.sources.streams.core import Stream from airbyte_cdk.sources.streams.http.http import HttpStream from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer diff --git a/unit_tests/sources/test_source_read.py b/unit_tests/sources/test_source_read.py index a25f54a5a..c89f963c3 100644 --- a/unit_tests/sources/test_source_read.py +++ b/unit_tests/sources/test_source_read.py @@ -22,7 +22,7 @@ TraceType, ) from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources import AbstractSource +from airbyte_cdk.sources.abstract_source import AbstractSource from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter from airbyte_cdk.sources.message import InMemoryMessageRepository diff --git a/unit_tests/test_connector.py b/unit_tests/test_connector.py index cf10dba01..e7e77478a 100644 --- a/unit_tests/test_connector.py +++ b/unit_tests/test_connector.py @@ -14,7 +14,7 @@ import pytest import yaml -from airbyte_cdk import Connector +from airbyte_cdk.connector import BaseConnector as Connector from airbyte_cdk.models import AirbyteConnectionStatus logger = logging.getLogger("airbyte") diff --git a/unit_tests/test_entrypoint.py b/unit_tests/test_entrypoint.py index 520131881..ec69b8135 100644 --- a/unit_tests/test_entrypoint.py +++ b/unit_tests/test_entrypoint.py @@ -43,8 +43,8 @@ TraceType, Type, ) -from airbyte_cdk.sources import Source from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor +from airbyte_cdk.sources.source import Source from airbyte_cdk.utils import AirbyteTracedException diff --git a/unit_tests/test_secure_logger.py b/unit_tests/test_secure_logger.py index 757a069c7..d4dc865c1 100644 --- a/unit_tests/test_secure_logger.py +++ b/unit_tests/test_secure_logger.py @@ -18,7 +18,7 @@ ConnectorSpecification, Type, ) -from airbyte_cdk.sources import Source +from 
airbyte_cdk.sources.source import Source SECRET_PROPERTY = "api_token" ANOTHER_SECRET_PROPERTY = "another_api_token"