diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fe3eade..45f2a4f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,3 +37,9 @@ jobs: - name: Runs tests with coverage run: nosetests --with-doctest -v --nocapture + + - name: Installs mypy types where available + run: mypy --install-types + + - name: Runs mypy linter and type checks + run: mypy --config-file ./mypy.ini singer diff --git a/.gitignore b/.gitignore index 03e17bb..6ea0045 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# Developer IDE settings +.vscode/settings.json + # Distribution / packaging .Python env/ diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..be866f5 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,17 @@ +[mypy] +python_version = 3.8 +warn_unused_configs = True +warn_return_any = True +exclude = tests + +# No typing provided in current version of jsonschema, backoff, and ciso8601. +# This suppresses the related mypy warnings: + +[mypy-jsonschema.*] +ignore_missing_imports = True + +[mypy-backoff.*] +ignore_missing_imports = True + +[mypy-ciso8601.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index 9de29ab..1eb8099 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ 'ipython', 'ipdb', 'nose', + 'mypy', 'unify==0.5' ] }, diff --git a/singer/bookmarks.py b/singer/bookmarks.py index 53b54ca..bf4789f 100644 --- a/singer/bookmarks.py +++ b/singer/bookmarks.py @@ -1,4 +1,8 @@ -def ensure_bookmark_path(state, path): +from typing import Dict, List, Optional, Any, cast + +StateDict = Dict[str, Any] + +def ensure_bookmark_path(state: StateDict, path: List[str]) -> StateDict: submap = state for path_component in path: if submap.get(path_component) is None: @@ -7,40 +11,46 @@ def ensure_bookmark_path(state, path): submap = submap[path_component] return state -def write_bookmark(state, tap_stream_id, key, val): +def write_bookmark( + state: StateDict, tap_stream_id: str, 
key: str, val: Any +) -> StateDict: state = ensure_bookmark_path(state, ['bookmarks', tap_stream_id]) state['bookmarks'][tap_stream_id][key] = val return state -def clear_bookmark(state, tap_stream_id, key): +def clear_bookmark(state: StateDict, tap_stream_id: str, key: str) -> StateDict: state = ensure_bookmark_path(state, ['bookmarks', tap_stream_id]) state['bookmarks'][tap_stream_id].pop(key, None) return state -def reset_stream(state, tap_stream_id): +def reset_stream(state: StateDict, tap_stream_id: str) -> StateDict: state = ensure_bookmark_path(state, ['bookmarks', tap_stream_id]) state['bookmarks'][tap_stream_id] = {} return state -def get_bookmark(state, tap_stream_id, key, default=None): +def get_bookmark( + state: StateDict, tap_stream_id: str, key: str, default: Optional[Any] = None +) -> Optional[Any]: return state.get('bookmarks', {}).get(tap_stream_id, {}).get(key, default) -def set_offset(state, tap_stream_id, offset_key, offset_value): +def set_offset( + state: StateDict, tap_stream_id: str, offset_key: Any, offset_value: Any +): state = ensure_bookmark_path(state, ['bookmarks', tap_stream_id, 'offset', offset_key]) state['bookmarks'][tap_stream_id]['offset'][offset_key] = offset_value return state -def clear_offset(state, tap_stream_id): +def clear_offset(state: StateDict, tap_stream_id: str): state = ensure_bookmark_path(state, ['bookmarks', tap_stream_id, 'offset']) state['bookmarks'][tap_stream_id]['offset'] = {} return state -def get_offset(state, tap_stream_id, default=None): +def get_offset(state: StateDict, tap_stream_id: str, default: Optional[Any] = None): return state.get('bookmarks', {}).get(tap_stream_id, {}).get('offset', default) -def set_currently_syncing(state, tap_stream_id): +def set_currently_syncing(state: StateDict, tap_stream_id: str): state['currently_syncing'] = tap_stream_id return state -def get_currently_syncing(state, default=None): - return state.get('currently_syncing', default) +def get_currently_syncing(state: 
StateDict, default: str = None) -> Optional[str]: + return cast(str, state.get('currently_syncing', default)) diff --git a/singer/catalog.py b/singer/catalog.py index 77424a9..aa7f6c7 100644 --- a/singer/catalog.py +++ b/singer/catalog.py @@ -1,6 +1,7 @@ '''Provides an object model for a Singer Catalog.''' import orjson import sys +from typing import Iterable, Optional, List, Any, cast from . import metadata as metadata_module from .bookmarks import get_currently_syncing @@ -10,7 +11,7 @@ LOGGER = get_logger() -def write_catalog(catalog): +def write_catalog(catalog: "Catalog") -> None: # If the catalog has no streams, log a warning if not catalog.streams: LOGGER.warning('Catalog being written with no streams.') @@ -22,10 +23,21 @@ def write_catalog(catalog): # pylint: disable=too-many-instance-attributes class CatalogEntry(): - def __init__(self, tap_stream_id=None, stream=None, - key_properties=None, schema=None, replication_key=None, - is_view=None, database=None, table=None, row_count=None, - stream_alias=None, metadata=None, replication_method=None): + def __init__( + self, + tap_stream_id: Optional[str] = None, + stream: Optional[str] = None, + key_properties: Optional[List[str]] = None, + schema: Optional[Schema] = None, + replication_key: Optional[str] = None, + is_view: Optional[bool] = None, + database: Optional[str] = None, + table: Optional[str] = None, + row_count: Optional[int] = None, + stream_alias: Optional[str] = None, + metadata: Optional[dict] = None, + replication_method: Optional[str] = None + ) -> None: self.tap_stream_id = tap_stream_id self.stream = stream @@ -83,22 +95,22 @@ def to_dict(self): class Catalog(): - def __init__(self, streams): + def __init__(self, streams: List[CatalogEntry]) -> None: self.streams = streams - def __str__(self): + def __str__(self) -> str: return str(self.__dict__) - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __eq__(self, other: Any) -> bool: + return cast(bool, self.__dict__ == 
other.__dict__) @classmethod - def load(cls, filename): + def load(cls, filename: str) -> "Catalog": with open(filename, encoding='utf-8') as fp: # pylint: disable=invalid-name return Catalog.from_dict(orjson.loads(fp.read())) @classmethod - def from_dict(cls, data): + def from_dict(cls, data: dict) -> "Catalog": # TODO: We may want to store streams as a dict where the key is a # tap_stream_id and the value is a CatalogEntry. This will allow # faster lookup based on tap_stream_id. This would be a breaking @@ -121,19 +133,19 @@ def from_dict(cls, data): streams.append(entry) return Catalog(streams) - def to_dict(self): + def to_dict(self) -> dict: return {'streams': [stream.to_dict() for stream in self.streams]} - def dump(self): + def dump(self) -> None: write_catalog(self) - def get_stream(self, tap_stream_id): + def get_stream(self, tap_stream_id: str) -> Optional[CatalogEntry]: for stream in self.streams: if stream.tap_stream_id == tap_stream_id: return stream return None - def _shuffle_streams(self, state): + def _shuffle_streams(self, state: dict) -> List[CatalogEntry]: currently_syncing = get_currently_syncing(state) if currently_syncing is None: @@ -149,7 +161,7 @@ def _shuffle_streams(self, state): return top_half + bottom_half - def get_selected_streams(self, state): + def get_selected_streams(self, state: dict) -> Iterable[CatalogEntry]: for stream in self._shuffle_streams(state): if not stream.is_selected(): LOGGER.info('Skipping stream: %s', stream.tap_stream_id) diff --git a/singer/logger.py b/singer/logger.py index 2453eb9..ab97799 100644 --- a/singer/logger.py +++ b/singer/logger.py @@ -3,7 +3,7 @@ import os -def get_logger(name='singer'): +def get_logger(name: str = 'singer') -> logging.Logger: """Return a Logger instance to use in singer.""" # Use custom logging config provided by environment variable if 'LOGGING_CONF_FILE' in os.environ and os.environ['LOGGING_CONF_FILE']: diff --git a/singer/messages.py b/singer/messages.py index d0294ca..cb368c9 
100644 --- a/singer/messages.py +++ b/singer/messages.py @@ -1,4 +1,6 @@ import sys +from typing import Dict, List, Any, Optional, Union +from datetime import datetime import pytz import orjson @@ -11,10 +13,10 @@ class Message(): '''Base class for messages.''' - def asdict(self): # pylint: disable=no-self-use + def asdict(self) -> dict: # pylint: disable=no-self-use raise Exception('Not implemented') - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, Message) and self.asdict() == other.asdict() def __repr__(self): @@ -89,7 +91,13 @@ class SchemaMessage(Message): key_properties=['id']) ''' - def __init__(self, stream, schema, key_properties, bookmark_properties=None): + def __init__( + self, + stream: str, + schema: dict, + key_properties: Optional[List[str]], + bookmark_properties: Optional[List[str]] = None, + ) -> None: self.stream = stream self.schema = schema self.key_properties = key_properties @@ -101,7 +109,7 @@ def __init__(self, stream, schema, key_properties, bookmark_properties=None): self.bookmark_properties = bookmark_properties - def asdict(self): + def asdict(self) -> dict: result = { 'type': 'SCHEMA', 'stream': self.stream, @@ -124,10 +132,10 @@ class StateMessage(Message): value={'users': '2017-06-19T00:00:00'}) ''' - def __init__(self, value): + def __init__(self, value: dict) -> None: self.value = value - def asdict(self): + def asdict(self) -> dict: return { 'type': 'STATE', 'value': self.value @@ -155,7 +163,7 @@ class ActivateVersionMessage(Message): version=2) ''' - def __init__(self, stream, version): + def __init__(self, stream: str, version: int) -> None: self.stream = stream self.version = version @@ -194,9 +202,14 @@ class BatchMessage(Message): """ def __init__( - self, stream, filepath, file_format=None, compression=None, - batch_size=None, time_extracted=None - ): + self, + stream: str, + filepath: str, + file_format: Optional[str] = None, + compression: Optional[str] = None, + batch_size: 
Optional[int] = None, + time_extracted: Optional[datetime] = None, + ) -> None: self.stream = stream self.filepath = filepath self.format = file_format or 'jsonl' @@ -207,8 +220,8 @@ def __init__( raise ValueError("'time_extracted' must be either None " + 'or an aware datetime (with a time zone)') - def asdict(self): - result = { + def asdict(self) -> dict: + result: Dict[str, Any] = { 'type': 'BATCH', 'stream': self.stream, 'filepath': self.filepath, @@ -224,14 +237,14 @@ def asdict(self): return result -def _required_key(msg, k): +def _required_key(msg: dict, k: str) -> Any: if k not in msg: raise Exception(f"Message is missing required key '{k}': {msg}") return msg[k] -def parse_message(msg): +def parse_message(msg: str) -> Optional[Message]: """Parse a message string into a Message object.""" # We are not using Decimals for parsing here. @@ -292,16 +305,19 @@ def parse_message(msg): return None -def format_message(message, option=0): +def format_message(message: Message, option: int = 0) -> str: return orjson.dumps(message.asdict(), option=option) -def write_message(message): +def write_message(message: Message) -> None: sys.stdout.buffer.write(format_message(message, option=orjson.OPT_APPEND_NEWLINE)) sys.stdout.buffer.flush() -def write_record(stream_name, record, stream_alias=None, time_extracted=None): +def write_record( + stream_name: str, record: dict, stream_alias: Optional[str] = None, + time_extracted: Union[datetime, str, None] = None, +) -> None: """Write a single record for the given stream. write_record("users", {"id": 2, "email": "mike@stitchdata.com"}) @@ -311,7 +327,7 @@ def write_record(stream_name, record, stream_alias=None, time_extracted=None): time_extracted=time_extracted)) -def write_records(stream_name, records): +def write_records(stream_name: str, records: List[dict]) -> None: """Write a list of records for the given stream. 
chris = {"id": 1, "email": "chris@stitchdata.com"} @@ -322,7 +338,10 @@ def write_records(stream_name, records): write_record(stream_name, record) -def write_schema(stream_name, schema, key_properties, bookmark_properties=None, stream_alias=None): +def write_schema( + stream_name: str, schema: dict, key_properties: List[str], + bookmark_properties: Optional[List[str]] = None, stream_alias: Optional[str] = None +) -> None: """Write a schema message. stream = 'test' @@ -343,7 +362,7 @@ def write_schema(stream_name, schema, key_properties, bookmark_properties=None, bookmark_properties=bookmark_properties)) -def write_state(value): +def write_state(value: dict) -> None: """Write a state message. write_state({'last_updated_at': '2017-02-14T09:21:00'}) @@ -351,7 +370,7 @@ def write_state(value): write_message(StateMessage(value=value)) -def write_version(stream_name, version): +def write_version(stream_name: str, version: int) -> None: """Write an activate version message. stream = 'test' @@ -361,9 +380,13 @@ def write_version(stream_name, version): write_message(ActivateVersionMessage(stream_name, version)) def write_batch( - stream_name, filepath, file_format=None, - compression=None, batch_size=None, time_extracted=None -): + stream_name: str, + filepath: str, + file_format: Optional[str] = None, + compression: Optional[str] = None, + batch_size: Optional[int] = None, + time_extracted: Optional[datetime] = None, +) -> None: """Write a batch message. stream = 'users' diff --git a/singer/metadata.py b/singer/metadata.py index 41153ea..0693dcb 100644 --- a/singer/metadata.py +++ b/singer/metadata.py @@ -1,30 +1,53 @@ -def new(): +from typing import Any, Dict, Optional, Tuple, List + +Breadcrumb = Tuple[str, ...] 
+CompiledMetadata = Dict[Breadcrumb, dict] + +def new() -> CompiledMetadata: return {} -def to_map(raw_metadata): +def to_map(raw_metadata: List[dict]) -> CompiledMetadata: return {tuple(md['breadcrumb']): md['metadata'] for md in raw_metadata} -def to_list(compiled_metadata): +def to_list(compiled_metadata: CompiledMetadata) -> List[dict]: return [{'breadcrumb': k, 'metadata': v} for k, v in compiled_metadata.items()] -def delete(compiled_metadata, breadcrumb, k): +def delete( + compiled_metadata: CompiledMetadata, + breadcrumb: Breadcrumb, + k: str +) -> None: del compiled_metadata[breadcrumb][k] -def write(compiled_metadata, breadcrumb, k, val): +def write( + compiled_metadata: CompiledMetadata, + breadcrumb: Breadcrumb, + k: str, + val: Any +) -> CompiledMetadata: if val is None: raise Exception() if breadcrumb in compiled_metadata: - compiled_metadata.get(breadcrumb).update({k: val}) + compiled_metadata[breadcrumb].update({k: val}) else: compiled_metadata[breadcrumb] = {k: val} return compiled_metadata -def get(compiled_metadata, breadcrumb, k): +def get( + compiled_metadata: CompiledMetadata, + breadcrumb: Breadcrumb, + k: str +) -> Any: return compiled_metadata.get(breadcrumb, {}).get(k) -def get_standard_metadata(schema=None, schema_name=None, key_properties=None, - valid_replication_keys=None, replication_method=None): - mdata = {} +def get_standard_metadata( + schema: Optional[dict] = None, + schema_name: Optional[str] = None, + key_properties: Optional[List[str]] = None, + valid_replication_keys: Optional[List[str]] = None, + replication_method: Optional[str] = None +) -> List[dict]: + mdata: CompiledMetadata = {} if key_properties is not None: mdata = write(mdata, (), 'table-key-properties', key_properties) diff --git a/singer/metrics.py b/singer/metrics.py index 93d6e64..baac3ee 100644 --- a/singer/metrics.py +++ b/singer/metrics.py @@ -44,6 +44,8 @@ import re import time from collections import namedtuple +from typing import Optional, cast +from 
logging import Logger from singer.logger import get_logger DEFAULT_LOG_INTERVAL = 60 @@ -51,32 +53,32 @@ class Status: '''Constants for status codes''' - succeeded = 'succeeded' - failed = 'failed' + succeeded: str = 'succeeded' + failed: str = 'failed' class Metric: '''Constants for metric names''' - record_count = 'record_count' - job_duration = 'job_duration' - http_request_duration = 'http_request_duration' + record_count: str = 'record_count' + job_duration: str = 'job_duration' + http_request_duration: str = 'http_request_duration' class Tag: '''Constants for commonly used tags''' - endpoint = 'endpoint' - job_type = 'job_type' - http_status_code = 'http_status_code' - status = 'status' + endpoint: str = 'endpoint' + job_type: str = 'job_type' + http_status_code: str = 'http_status_code' + status: str = 'status' Point = namedtuple('Point', ['metric_type', 'metric', 'value', 'tags']) -def log(logger, point): +def log(logger: Logger, point: Point) -> None: '''Log a single data point.''' result = { 'type': point.metric_type, @@ -113,7 +115,9 @@ class Counter(): ''' - def __init__(self, metric, tags=None, log_interval=DEFAULT_LOG_INTERVAL): + def __init__( + self, metric: str, tags=None, log_interval=DEFAULT_LOG_INTERVAL + ) -> None: self.metric = metric self.value = 0 self.tags = tags if tags else {} @@ -121,22 +125,22 @@ def __init__(self, metric, tags=None, log_interval=DEFAULT_LOG_INTERVAL): self.logger = get_logger() self.last_log_time = time.time() - def __enter__(self): + def __enter__(self) -> "Counter": self.last_log_time = time.time() return self - def increment(self, amount=1): + def increment(self, amount=1) -> None: '''Increments value by the specified amount.''' self.value += amount if self._ready_to_log(): self._pop() - def _pop(self): + def _pop(self) -> None: log(self.logger, Point('counter', self.metric, self.value, self.tags)) self.value = 0 self.last_log_time = time.time() - def __exit__(self, exc_type, exc_value, traceback): + def 
__exit__(self, exc_type, exc_value, traceback) -> None: self._pop() def _ready_to_log(self): @@ -170,19 +174,19 @@ class Timer(): # pylint: disable=too-few-public-methods }, ''' - def __init__(self, metric, tags): + def __init__(self, metric: str, tags) -> None: self.metric = metric self.tags = tags if tags else {} self.logger = get_logger() - self.start_time = None + self.start_time: Optional[float] = None - def __enter__(self): + def __enter__(self) -> "Timer": self.start_time = time.time() return self - def elapsed(self): + def elapsed(self) -> float: '''Return elapsed time''' - return time.time() - self.start_time + return time.time() - cast(float, self.start_time) # Assumes not null def __exit__(self, exc_type, exc_value, traceback): if Tag.status not in self.tags: @@ -193,7 +197,7 @@ def __exit__(self, exc_type, exc_value, traceback): log(self.logger, Point('timer', self.metric, self.elapsed(), self.tags)) -def record_counter(endpoint=None, log_interval=DEFAULT_LOG_INTERVAL): +def record_counter(endpoint=None, log_interval=DEFAULT_LOG_INTERVAL) -> Counter: '''Use for counting records retrieved from the source. 
with singer.metrics.record_counter(endpoint="users") as counter: @@ -231,7 +235,7 @@ def job_timer(job_type=None): return Timer(Metric.job_duration, tags) -def parse(line): +def parse(line: str) -> Optional[Point]: '''Parse a Point from a log line and return it, or None if no data point.''' match = re.match(r'^INFO METRIC: (.*)$', line) if match: diff --git a/singer/schema.py b/singer/schema.py index 108f50f..0f5bbe4 100644 --- a/singer/schema.py +++ b/singer/schema.py @@ -1,6 +1,8 @@ # pylint: disable=redefined-builtin, too-many-arguments, invalid-name '''Provides an object model for JSON Schema''' +from typing import Any, Dict, cast + import orjson # These are standard keys defined in the JSON Schema spec @@ -32,11 +34,26 @@ class Schema(): # pylint: disable=too-many-instance-attributes ''' # pylint: disable=too-many-locals - def __init__(self, type=None, format=None, properties=None, items=None, - selected=None, inclusion=None, description=None, minimum=None, - maximum=None, exclusiveMinimum=None, exclusiveMaximum=None, - multipleOf=None, maxLength=None, minLength=None, additionalProperties=None, - anyOf=None, patternProperties=None): + def __init__( + self, + type=None, + format=None, + properties=None, + items=None, + selected=None, + inclusion=None, + description=None, + minimum=None, + maximum=None, + exclusiveMinimum=None, + exclusiveMaximum=None, + multipleOf=None, + maxLength=None, + minLength=None, + additionalProperties=None, + anyOf=None, + patternProperties=None + ) -> None: self.type = type self.properties = properties @@ -56,20 +73,20 @@ def __init__(self, type=None, format=None, properties=None, items=None, self.additionalProperties = additionalProperties self.patternProperties = patternProperties - def __str__(self): + def __str__(self) -> str: return orjson.dumps(self.to_dict()).decode('utf-8') - def __repr__(self): + def __repr__(self) -> str: pairs = [k + '=' + repr(v) for k, v in self.__dict__.items()] args = ', '.join(pairs) return 'Schema(' 
+ args + ')' - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __eq__(self, other: Any) -> bool: + return cast(bool, self.__dict__ == other.__dict__) - def to_dict(self): + def to_dict(self) -> dict: '''Return the raw JSON Schema as a (possibly nested) dict.''' - result = {} + result: Dict[str, Any] = {} if self.properties is not None: result['properties'] = { @@ -89,7 +106,7 @@ def to_dict(self): return result @classmethod - def from_dict(cls, data, **schema_defaults): + def from_dict(cls, data: dict, **schema_defaults) -> "Schema": '''Initialize a Schema object based on the JSON Schema structure. :param schema_defaults: The default values to the Schema diff --git a/singer/statediff.py b/singer/statediff.py index bc21fd5..f40bb3b 100644 --- a/singer/statediff.py +++ b/singer/statediff.py @@ -1,11 +1,14 @@ import collections +from typing import Sequence, Union, List, Any, Tuple # Named tuples for holding add, change, and remove operations Add = collections.namedtuple('Add', ['path', 'newval']) Change = collections.namedtuple('Change', ['path', 'oldval', 'newval']) Remove = collections.namedtuple('Remove', ['path', 'oldval']) -def paths(data, base=None): +def paths( + data: Union[list, dict, Any], base: Tuple[Any, ...] = None +) -> List[Tuple[Tuple[Any, ...], Any]]: '''Walk a data structure and return a list of (path, value) tuples, where each path is the path to a leaf node in the data structure and the value is the value it points to. Each path will be a tuple. @@ -28,7 +31,7 @@ def paths(data, base=None): return result -def diff(oldstate, newstate): +def diff(oldstate: dict, newstate: dict) -> Sequence[Union[Add, Change, Remove]]: '''Compare two states, returning a list of Add, Change, and Remove objects. 
@@ -55,7 +58,7 @@ def diff(oldstate, newstate): all_paths.update(set(olddict.keys())) all_paths.update(set(newdict.keys())) - result = [] + result: List[Any] = [] for path in sorted(all_paths): if path in olddict: if path in newdict: diff --git a/singer/transform.py b/singer/transform.py index f117570..cb42340 100644 --- a/singer/transform.py +++ b/singer/transform.py @@ -1,6 +1,7 @@ import datetime import logging import re +from typing import List, Optional, Set, Union, Any, Tuple from jsonschema import RefResolver import singer.metadata @@ -20,7 +21,7 @@ ] -def string_to_datetime(value): +def string_to_datetime(value: str) -> Optional[str]: try: return strftime(strptime_to_utc(value)) except Exception as ex: @@ -28,16 +29,18 @@ def string_to_datetime(value): return None -def unix_milliseconds_to_datetime(value): - return strftime(datetime.datetime.fromtimestamp(float(value) / 1000.0, datetime.timezone.utc)) +def unix_milliseconds_to_datetime(value: Union[str, float]) -> str: + return strftime( + datetime.datetime.fromtimestamp(float(value) / 1000.0, datetime.timezone.utc) + ) -def unix_seconds_to_datetime(value): +def unix_seconds_to_datetime(value: Union[str, int, float]) -> str: return strftime(datetime.datetime.fromtimestamp(int(value), datetime.timezone.utc)) class SchemaMismatch(Exception): - def __init__(self, errors): + def __init__(self, errors: List[Any]) -> None: if not errors: msg = 'An error occured during transform that was not a schema mismatch' @@ -55,13 +58,16 @@ class SchemaKey: any_of = 'anyOf' class Error: - def __init__(self, path, data, schema=None, logging_level=logging.INFO): + def __init__( + self, path: list, data: dict, schema: Optional[dict] = None, + logging_level=logging.INFO + ) -> None: self.path = path self.data = data self.schema = schema self.logging_level = logging_level - def tostr(self): + def tostr(self) -> str: path = '.'.join(map(str, self.path)) if self.schema: if self.logging_level >= logging.INFO: @@ -79,14 +85,16 @@ def 
tostr(self): class Transformer: - def __init__(self, integer_datetime_fmt=NO_INTEGER_DATETIME_PARSING, pre_hook=None): + def __init__( + self, integer_datetime_fmt=NO_INTEGER_DATETIME_PARSING, pre_hook=None + ) -> None: self.integer_datetime_fmt = integer_datetime_fmt self.pre_hook = pre_hook - self.removed = set() - self.filtered = set() - self.errors = [] + self.removed: Set[str] = set() + self.filtered: Set[str] = set() + self.errors: List[Error] = [] - def log_warning(self): + def log_warning(self) -> None: if self.filtered: LOGGER.debug('Filtered %s paths during transforms ' 'as they were unsupported or not selected:\n\t%s', @@ -103,17 +111,23 @@ def log_warning(self): # Output list format to parse for reporting LOGGER.debug('Removed paths list: %s', sorted(self.removed)) - def __enter__(self): + def __enter__(self) -> "Transformer": return self - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.log_warning() - def filter_data_by_metadata(self, data, metadata): + def filter_data_by_metadata( + self, data: Union[dict, Any], metadata: Optional[dict] + ) -> dict: if isinstance(data, dict) and metadata: for field_name in list(data.keys()): - selected = singer.metadata.get(metadata, ('properties', field_name), 'selected') - inclusion = singer.metadata.get(metadata, ('properties', field_name), 'inclusion') + selected = singer.metadata.get( + metadata, ('properties', field_name), 'selected' + ) + inclusion = singer.metadata.get( + metadata, ('properties', field_name), 'inclusion' + ) if inclusion == 'automatic': continue @@ -131,7 +145,7 @@ def filter_data_by_metadata(self, data, metadata): return data - def transform(self, data, schema, metadata=None): + def transform(self, data: dict, schema: dict, metadata: Optional[dict] =None): data = self.filter_data_by_metadata(data, metadata) success, transformed_data = self.transform_recur(data, schema, []) @@ -140,7 +154,9 @@ def transform(self, data, schema, metadata=None): return transformed_data - def 
transform_recur(self, data, schema, path): + def transform_recur( + self, data: dict, schema: dict, path: list + ) -> Tuple[bool, Optional[dict]]: if 'anyOf' in schema: return self._transform_anyof(data, schema, path) @@ -165,7 +181,7 @@ def transform_recur(self, data, schema, path): self.errors.append(Error(path, data, schema, logging_level=LOGGER.level)) return False, None - def _transform_anyof(self, data, schema, path): + def _transform_anyof(self, data: dict, schema: dict, path: list) -> Tuple[bool, Any]: subschemas = schema['anyOf'] for subschema in subschemas: success, transformed_data = self.transform_recur(data, subschema, path) @@ -176,7 +192,9 @@ def _transform_anyof(self, data, schema, path): self.errors.append(Error(path, data, schema, logging_level=LOGGER.level)) return False, None - def _transform_object(self, data, schema, path, pattern_properties): + def _transform_object( + self, data: dict, schema: dict, path: list, pattern_properties: dict + ) -> Tuple[bool, dict]: # We do not necessarily have a dict to transform here. The schema's # type could contain multiple possible values. Eg: # ["null", "object", "string"] @@ -209,7 +227,7 @@ def _transform_object(self, data, schema, path, pattern_properties): return all(successes), result - def _transform_array(self, data, schema, path): + def _transform_array(self, data: list, schema: dict, path: list) -> Tuple[bool, list]: # We do not necessarily have a list to transform here. The schema's # type could contain multiple possible values. 
Eg: # ["null", "array", "integer"] @@ -224,7 +242,7 @@ def _transform_array(self, data, schema, path): return all(successes), result - def _transform_datetime(self, value): + def _transform_datetime(self, value: Optional[str]) -> Optional[str]: if value is None or value == '': return None # Short circuit in the case of null or empty string @@ -242,7 +260,7 @@ def _transform_datetime(self, value): except Exception: return string_to_datetime(value) - def _transform(self, data, typ, schema, path): + def _transform(self, data, typ, schema, path: list) -> Tuple[bool, Optional[Any]]: if self.pre_hook: data = self.pre_hook(data, typ, schema) diff --git a/singer/utils.py b/singer/utils.py index 7579280..96c0f10 100644 --- a/singer/utils.py +++ b/singer/utils.py @@ -2,8 +2,12 @@ import collections import datetime import functools +from logging import Logger + import orjson + import time +from typing import Callable, Union, List, cast, Iterable from warnings import warn import dateutil.parser @@ -16,17 +20,17 @@ DATETIME_FMT = '%04Y-%m-%dT%H:%M:%S.%fZ' DATETIME_FMT_SAFE = '%Y-%m-%dT%H:%M:%S.%fZ' -def now(): +def now() -> datetime.datetime: return datetime.datetime.utcnow().replace(tzinfo=pytz.UTC) -def strptime_with_tz(dtime): +def strptime_with_tz(dtime: str) -> datetime.datetime: d_object = dateutil.parser.parse(dtime) if d_object.tzinfo is None: return d_object.replace(tzinfo=pytz.UTC) return d_object -def strptime(dtime): +def strptime(dtime: str) -> datetime.datetime: """DEPRECATED Use strptime_to_utc instead. Parse DTIME according to DATETIME_PARSE without TZ safety. 
@@ -57,14 +61,14 @@ def strptime(dtime): return datetime.datetime.strptime(dtime, DATETIME_PARSE) -def strptime_to_utc(dtimestr): +def strptime_to_utc(dtimestr: str) -> datetime.datetime: d_object = dateutil.parser.parse(dtimestr) if d_object.tzinfo is None: return d_object.replace(tzinfo=pytz.UTC) return d_object.astimezone(tz=pytz.UTC) -def strftime(dtime, format_str=DATETIME_FMT): +def strftime(dtime: datetime.datetime, format_str: str = DATETIME_FMT) -> str: if dtime.utcoffset() != datetime.timedelta(0): raise Exception('datetime must be pegged at UTC tzoneinfo') @@ -78,7 +82,7 @@ def strftime(dtime, format_str=DATETIME_FMT): return dt_str -def ratelimit(limit, every): +def ratelimit(limit: int, every: int) -> Callable: def limitdecorator(func): times = collections.deque() @@ -99,17 +103,19 @@ def wrapper(*args, **kwargs): return limitdecorator -def chunk(array, num): +def chunk(array: list, num: int) -> Iterable[list]: for i in range(0, len(array), num): yield array[i:i + num] -def load_json(path): +def load_json(path: str) -> Union[dict, list]: with open(path, encoding='utf-8') as fil: return orjson.loads(fil.read()) -def update_state(state, entity, dtime): +def update_state( + state: dict, entity, dtime: Union[str, datetime.datetime] +) -> None: if dtime is None: return @@ -123,7 +129,7 @@ def update_state(state, entity, dtime): state[entity] = dtime -def parse_args(required_config_keys): +def parse_args(required_config_keys: List[str]) -> argparse.Namespace: '''Parse standard command-line args. 
Parses the command-line arguments mentioned in the SPEC and the @@ -184,45 +190,51 @@ def parse_args(required_config_keys): return args -def check_config(config, required_keys): +def check_config(config: dict, required_keys: List[str]) -> None: missing_keys = [key for key in required_keys if key not in config] if missing_keys: raise Exception(f'Config is missing required keys: {missing_keys}') -def backoff(exceptions, giveup): +def backoff(exceptions, giveup) -> Callable: """Decorates a function to retry up to 5 times using an exponential backoff function. exceptions is a tuple of exception classes that are retried giveup is a function that accepts the exception and returns True to retry """ - return backoff_module.on_exception( - backoff_module.expo, - exceptions, - max_tries=5, - giveup=giveup, - factor=2) - - -def exception_is_4xx(exception): + return cast( + Callable, + backoff_module.on_exception( + backoff_module.expo, + exceptions, + max_tries=5, + giveup=giveup, + factor=2 + ) + ) + + +def exception_is_4xx(exception: Exception) -> bool: """Returns True if exception is in the 4xx range.""" if not hasattr(exception, 'response'): return False - if exception.response is None: + response = exception.response # type: ignore # Duck-typed requests.Response + + if response is None: return False - if not hasattr(exception.response, 'status_code'): + if not hasattr(response, 'status_code'): return False - return 400 <= exception.response.status_code < 500 + return 400 <= cast(int, response.status_code) < 500 -def handle_top_exception(logger): +def handle_top_exception(logger: Logger) -> Callable: """A decorator that will catch exceptions and log the exception's message as a CRITICAL log.""" - def decorator(fnc): + def decorator(fnc: Callable) -> Callable: @functools.wraps(fnc) def wrapped(*args, **kwargs): try: @@ -234,7 +246,9 @@ def wrapped(*args, **kwargs): return decorator -def should_sync_field(inclusion, selected, default=False): +def should_sync_field( + 
inclusion: str, selected: bool, default: bool = False +) -> bool: """ Returns True if a field should be synced.