diff --git a/prance/__init__.py b/prance/__init__.py index 63222fc..ca7745b 100644 --- a/prance/__init__.py +++ b/prance/__init__.py @@ -6,14 +6,21 @@ Included is a BaseParser that reads and validates swagger specs, and a ResolvingParser that additionally resolves any $ref references. """ +import sys +from typing import Any +from typing import Dict +from typing import Optional +from typing import Union +from urllib.parse import ParseResult + +from packaging.version import Version # type: ignore[import-not-found] + +from prance.util.path import JsonValue __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser" __license__ = "MIT" __all__ = ("util", "mixins", "cli", "convert") -import sys - -from packaging.version import Version try: from prance._version import version as __version__ @@ -55,7 +62,13 @@ class BaseParser(mixins.YAMLMixin, mixins.JSONMixin): SPEC_VERSION_2_PREFIX = "Swagger/OpenAPI" SPEC_VERSION_3_PREFIX = "OpenAPI" - def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): + def __init__( + self, + url: str | None = None, + spec_string: str | None = None, + lazy: bool = False, + **kwargs: Any, + ) -> None: """ Load, parse and validate specs. @@ -82,7 +95,7 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): ) # Keep the parameters around for later use - self.url = None + self.url: ParseResult if url: from .util.url import absurl from .util.fs import abspath @@ -90,24 +103,26 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): self.url = absurl(url, abspath(os.getcwd())) else: - self.url = _PLACEHOLDER_URL + from urllib.parse import urlparse + + self.url = urlparse(_PLACEHOLDER_URL) - self._spec_string = spec_string + self._spec_string: str | None = spec_string # Initialize variables we're filling later - self.specification = None - self.version = None - self.version_name = None - self.version_parsed = () - self.valid = False + self.specification: JsonValue | None = None + self.version: str | None = None + self.version_name: str | None = None + self.version_parsed: tuple = () + self.valid: bool = False # Add kw args as options - self.options = kwargs + self.options: dict[str, Any] = kwargs # Verify backend from .util import default_validation_backend - self.backend = self.options.get("backend", default_validation_backend()) + self.backend: str = self.options.get("backend", default_validation_backend()) if self.backend not in BaseParser.BACKENDS.keys(): raise ValueError( f"Backend may only be one of {BaseParser.BACKENDS.keys()}!" @@ -117,7 +132,7 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): if not lazy: self.parse() - def parse(self): # noqa: F811 + def parse(self) -> None: # noqa: F811 """ When the BaseParser was lazily created, load and parse now. @@ -128,7 +143,7 @@ def parse(self): # noqa: F811 strict = self.options.get("strict", True) # If we have a file name, we need to read that in. - if self.url and self.url != _PLACEHOLDER_URL: + if self.url and self.url.geturl() != _PLACEHOLDER_URL: from .util.url import fetch_url encoding = self.options.get("encoding", None) @@ -138,7 +153,7 @@ def parse(self): # noqa: F811 if self._spec_string: from .util.formats import parse_spec - self.specification = parse_spec(self._spec_string, self.url) + self.specification = parse_spec(self._spec_string, self.url.path) # If we have a parsed spec, convert it to JSON. Then we can validate # the JSON. At this point, we *require* a parsed specification to exist, @@ -147,7 +162,7 @@ def parse(self): # noqa: F811 self._validate() - def _validate(self): + def _validate(self) -> None: # Ensure specification is a mapping from collections.abc import Mapping @@ -159,18 +174,22 @@ def _validate(self): # Fetch the spec version. Note that this is the spec version the spec # *claims* to be; we later set the one we actually could validate as. - spec_version = None + spec_version: str | None = None if spec_version is None: - spec_version = self.specification.get("openapi", None) + version_val = self.specification.get("openapi", None) + if isinstance(version_val, str): + spec_version = version_val if spec_version is None: - spec_version = self.specification.get("swagger", None) + version_val = self.specification.get("swagger", None) + if isinstance(version_val, str): + spec_version = version_val if spec_version is None: raise ValidationError( "Could not determine specification schema " "version!" ) # Try parsing the spec version, examine the first component. - import packaging.version + import packaging.version # type: ignore[import-not-found] parsed = packaging.version.parse(spec_version) if parsed.major not in versions: @@ -187,7 +206,7 @@ def _validate(self): validator(parsed) self.valid = True - def __set_version(self, prefix, version: Version): + def __set_version(self, prefix: str, version: Version) -> None: self.version_name = prefix self.version_parsed = version.release @@ -196,12 +215,12 @@ def __set_version(self, prefix, version: Version): stringified = "%d.%d" % (version.major, version.minor) self.version = f"{self.version_name} {stringified}" - def _validate_flex(self, spec_version: Version): # pragma: nocover + def _validate_flex(self, spec_version: Version) -> None: # pragma: nocover # Set the version independently of whether validation succeeds self.__set_version(BaseParser.SPEC_VERSION_2_PREFIX, spec_version) - from flex.exceptions import ValidationError as JSEValidationError - from flex.core import parse as validate + from flex.exceptions import ValidationError as JSEValidationError # type: ignore[import-not-found] + from flex.core import parse as validate # type: ignore[import-not-found] try: validate(self.specification) @@ -212,12 +231,12 @@ def _validate_flex(self, spec_version: Version): # pragma: nocover def _validate_swagger_spec_validator( self, spec_version: Version - ): # pragma: nocover + ) -> None: # pragma: nocover # Set the version independently of whether validation succeeds self.__set_version(BaseParser.SPEC_VERSION_2_PREFIX, spec_version) - from swagger_spec_validator.common import SwaggerValidationError as SSVErr - from swagger_spec_validator.validator20 import validate_spec + from swagger_spec_validator.common import SwaggerValidationError as SSVErr # type: ignore[import-not-found] + from swagger_spec_validator.validator20 import validate_spec # type: ignore[import-not-found] try: validate_spec(self.specification) @@ -228,10 +247,10 @@ def _validate_swagger_spec_validator( def _validate_openapi_spec_validator( self, spec_version: Version - ): # pragma: nocover - from openapi_spec_validator import validate - from jsonschema.exceptions import ValidationError as JSEValidationError - from referencing.exceptions import Unresolvable + ) -> None: # pragma: nocover + from openapi_spec_validator import validate # type: ignore[import-not-found] + from jsonschema.exceptions import ValidationError as JSEValidationError # type: ignore[import-untyped] + from referencing.exceptions import Unresolvable # type: ignore[import-not-found] # Validate according to detected version. Unsupported versions are # already caught outside of this function. @@ -253,7 +272,7 @@ def _validate_openapi_spec_validator( except Unresolvable as ref_unres: raise_from(ValidationError, ref_unres) - def _strict_warning(self): + def _strict_warning(self) -> str: """Return a warning if strict mode is off.""" if self.options.get("strict", True): return ( @@ -269,7 +288,13 @@ def _strict_warning(self): class ResolvingParser(BaseParser): """The ResolvingParser extends BaseParser with resolving references by inlining.""" - def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): + def __init__( + self, + url: str | None = None, + spec_string: str | None = None, + lazy: bool = False, + **kwargs: Any, + ) -> None: """ See :py:class:`BaseParser`. @@ -280,11 +305,11 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs): Additional parameters, see :py::class:`util.RefResolver`. """ # Create a reference cache - self.__reference_cache = {} + self.__reference_cache: dict[str | tuple, JsonValue] = {} BaseParser.__init__(self, url=url, spec_string=spec_string, lazy=lazy, **kwargs) - def _validate(self): + def _validate(self) -> None: # We have a problem with the BaseParser's validate function: the # jsonschema implementation underlying it does not accept relative # path references, but the Swagger specs allow them: @@ -300,7 +325,7 @@ def _validate(self): "resolve_method", "strict", ) - forward_args = { + forward_args: dict[str, Any] = { k: v for (k, v) in self.options.items() if k in forward_arg_names } resolver = RefResolver( @@ -318,10 +343,10 @@ def _validate(self): # Underscored to allow some time for the public API to be stabilized. class _TranslatingParser(BaseParser): - def _validate(self): + def _validate(self) -> None: from .util.translator import _RefTranslator - translator = _RefTranslator(self.specification, self.url) + translator = _RefTranslator(self.specification, self.url.geturl()) translator.translate_references() self.specification = translator.specs diff --git a/prance/cli.py b/prance/cli.py index 8df9ce8..5bc1008 100644 --- a/prance/cli.py +++ b/prance/cli.py @@ -1,4 +1,13 @@ """CLI for prance.""" +from typing import Any +from typing import Optional +from typing import Tuple + +import click # type: ignore[import-not-found] + +import prance +from prance.util import default_validation_backend +from prance.util.path import JsonValue __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser" @@ -6,13 +15,7 @@ __all__ = () -import click - -import prance -from prance.util import default_validation_backend - - -def __write_to_file(filename, specs): # noqa: N802 +def __write_to_file(filename: str, specs: JsonValue) -> None: # noqa: N802 """ Write specs to the given filename. @@ -24,7 +27,9 @@ def __write_to_file(filename, specs): # noqa: N802 fs.write_file(filename, contents) -def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802 +def __parser_for_url( + url: str, resolve: bool, backend: str, strict: bool, encoding: str | None +) -> tuple[prance.BaseParser, str]: # noqa: N802 """Return a parser instance for the URL and the given parameters.""" # Try the URL formatted = click.format_filename(url) @@ -39,7 +44,7 @@ def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802 url = fsurl # Create parser to use - parser = None + parser: prance.BaseParser if resolve: click.echo(" -> Resolving external references.") parser = prance.ResolvingParser( @@ -56,7 +61,7 @@ def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802 return parser, formatted -def __validate(parser, name): # noqa: N802 +def __validate(parser: prance.BaseParser, name: str) -> None: # noqa: N802 """Validate a spec using this parser.""" from prance.util.url import ResolutionError from prance import ValidationError @@ -76,14 +81,14 @@ def __validate(parser, name): # noqa: N802 @click.group() @click.version_option(version=prance.__version__) -def cli(): +def cli() -> None: pass # pragma: no cover class GroupWithCommandOptions(click.Group): """Allow application of options to group with multi command.""" - def add_command(self, cmd, name=None): + def add_command(self, cmd: click.Command, name: str | None = None) -> None: click.Group.add_command(self, cmd, name=name) # add the group parameters to the command @@ -94,8 +99,8 @@ def add_command(self, cmd, name=None): cmd.invoke = self.build_command_invoke(cmd.invoke) self.invoke_without_command = True - def build_command_invoke(self, original_invoke): - def command_invoke(ctx): + def build_command_invoke(self, original_invoke: Any) -> Any: + def command_invoke(ctx: click.Context) -> None: """Insert invocation of group function.""" # separate the group parameters ctx.obj = dict(_params=dict()) @@ -145,7 +150,9 @@ def command_invoke(ctx): "encoding for all files. Does not work on remote URLs.", ) @click.pass_context -def backend_options(ctx, resolve, backend, strict, encoding): +def backend_options( + ctx: click.Context, resolve: bool, backend: str, strict: bool, encoding: str | None +) -> None: ctx.obj["resolve"] = resolve ctx.obj["backend"] = backend ctx.obj["strict"] = strict @@ -171,7 +178,9 @@ def backend_options(ctx, resolve, backend, strict, encoding): nargs=-1, ) @click.pass_context -def validate(ctx, output_file, urls): +def validate( + ctx: click.Context, output_file: str | None, urls: tuple[str, ...] +) -> None: """ Validate the given spec or specs. @@ -226,7 +235,7 @@ def validate(ctx, output_file, urls): required=False, ) @click.pass_context -def compile(ctx, url_or_path, output_file): +def compile(ctx: click.Context, url_or_path: str, output_file: str | None) -> None: """ Compile the given spec, resolving references if required. @@ -273,7 +282,7 @@ def compile(ctx, url_or_path, output_file): nargs=1, required=False, ) -def convert(url_or_path, output_file): +def convert(url_or_path: str, output_file: str | None) -> None: """ Convert the given spec to OpenAPI 3.x.y. diff --git a/prance/convert.py b/prance/convert.py index 1ec498a..fec5eae 100644 --- a/prance/convert.py +++ b/prance/convert.py @@ -3,6 +3,19 @@ The functions use https://converter.swagger.io/ APIs for conversion. """ +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union +from urllib.parse import ParseResult + +from prance.util.path import JsonValue + +if TYPE_CHECKING: + from prance import BaseParser __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2018 Jens Finkhaeuser" @@ -14,7 +27,9 @@ class ConversionError(ValueError): pass # pragma: nocover -def convert_str(spec_str, filename=None, **kwargs): +def convert_str( + spec_str: str, filename: str | None = None, **kwargs: str | None +) -> tuple[str, str]: """ Convert the serialized spec. @@ -41,7 +56,7 @@ def convert_str(spec_str, filename=None, **kwargs): headers = {"accept": content_type, "content-type": content_type} # Convert via API - import requests + import requests # type: ignore[import-untyped] r = requests.post( "https://converter.swagger.io/api/convert", data=data, headers=headers @@ -54,7 +69,9 @@ def convert_str(spec_str, filename=None, **kwargs): return r.text, "{}; {}".format(r.headers["content-type"], r.apparent_encoding) -def convert_url(url, cache={}): +def convert_url( + url: str | ParseResult, cache: dict[str, tuple[str, str | None]] | None = None +) -> tuple[str, str]: """ Fetch a URL, and try to convert it to OpenAPI 3.x.y. @@ -65,7 +82,14 @@ def convert_url(url, cache={}): :raises ConversionError: when conversion fails. """ # Fetch URL contents - from .util.url import fetch_url_text + from .util.url import absurl, fetch_url_text + + if cache is None: + cache = {} + + # Ensure url is a ParseResult + if isinstance(url, str): + url = absurl(url) content, content_type = fetch_url_text(url, cache) @@ -73,7 +97,12 @@ def convert_url(url, cache={}): return convert_str(content, None, content_type=content_type) -def convert_spec(parser_or_spec, parser_klass=None, *args, **kwargs): +def convert_spec( + parser_or_spec: Union[JsonValue, "BaseParser"], + parser_klass: type["BaseParser"] | None = None, + *args: Any, + **kwargs: Any, +) -> "BaseParser": """ Convert an already parsed spec to OpenAPI 3.x.y. @@ -104,9 +133,9 @@ def convert_spec(parser_or_spec, parser_klass=None, *args, **kwargs): :rtype: BaseParser or derived. """ # Figure out exact configuration to use - klass = None - options = None - spec = None + klass: type["BaseParser"] + options: dict[str, Any] + spec: JsonValue from . import BaseParser diff --git a/prance/mixins.py b/prance/mixins.py index 3923165..e416e06 100644 --- a/prance/mixins.py +++ b/prance/mixins.py @@ -3,6 +3,9 @@ The Mixins are here mostly for separation of concerns. """ +from typing import Any +from typing import cast +from typing import Optional __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2018 Jens Finkhaeuser" @@ -17,9 +20,12 @@ class CacheSpecsMixin: It does so by caching a shallow copy on-demand. """ + # This attribute is expected to be provided by the class using this mixin + specification: Any + __CACHED_SPECS = "__cached_specs" - def specs_updated(self): + def specs_updated(self) -> bool: """ Test if self.specficiation changed. @@ -51,7 +57,7 @@ class YAMLMixin(CacheSpecsMixin): __YAML = "__yaml" - def yaml(self): + def yaml(self) -> str: """ Return a YAML representation of the specifications. @@ -60,10 +66,10 @@ def yaml(self): """ # Query specs_updated first to start caching if self.specs_updated() or not getattr(self, self.__YAML, None): - import yaml + import yaml # type: ignore[import-untyped] setattr(self, self.__YAML, yaml.dump(self.specification)) - return getattr(self, self.__YAML) + return cast(str, getattr(self, self.__YAML)) class JSONMixin(CacheSpecsMixin): @@ -75,7 +81,7 @@ class JSONMixin(CacheSpecsMixin): __JSON = "__json" - def json(self): + def json(self) -> str: """ Return a JSON representation of the specifications. @@ -87,4 +93,4 @@ def json(self): import json setattr(self, self.__JSON, json.dumps(self.specification)) - return getattr(self, self.__JSON) + return cast(str, getattr(self, self.__JSON)) diff --git a/prance/util/__init__.py b/prance/util/__init__.py index 149bb32..121a2fc 100644 --- a/prance/util/__init__.py +++ b/prance/util/__init__.py @@ -1,12 +1,19 @@ """This submodule contains utility code for Prance.""" +from collections.abc import Mapping +from collections.abc import MutableMapping +from typing import List +from typing import Tuple +from typing import TypeVar __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser" __license__ = "MIT" __all__ = ("iterators", "fs", "formats", "resolver", "url", "path", "exceptions") +MappingT = TypeVar("MappingT", bound=MutableMapping) -def stringify_keys(data): + +def stringify_keys(data: MappingT) -> MappingT: """ Recursively stringify keys in a dict-like object. @@ -22,32 +29,32 @@ def stringify_keys(data): for key, value in data.items(): if not isinstance(key, str): key = str(key) - if isinstance(value, Mapping): + if isinstance(value, MutableMapping): value = stringify_keys(value) ret[key] = value return ret -def validation_backends(): +def validation_backends() -> tuple[str, ...]: """Return a list of validation backends supported by the environment.""" - ret = [] + ret: list[str] = [] try: - import flex # noqa: F401 + import flex # type: ignore[import-not-found] # noqa: F401 ret.append("flex") # pragma: nocover except (ImportError, SyntaxError): # pragma: nocover pass try: - import openapi_spec_validator # noqa: F401 + import openapi_spec_validator # type: ignore[import-not-found] # noqa: F401 ret.append("openapi-spec-validator") # pragma: nocover except (ImportError, SyntaxError): # pragma: nocover pass try: - import swagger_spec_validator # noqa: F401 + import swagger_spec_validator # type: ignore[import-not-found] # noqa: F401 ret.append("swagger-spec-validator") # pragma: nocover except (ImportError, SyntaxError): # pragma: nocover @@ -56,7 +63,7 @@ def validation_backends(): return tuple(ret) -def default_validation_backend(): +def default_validation_backend() -> str: """Return the default validation backend, or raise an error.""" backends = validation_backends() if len(backends) <= 0: # pragma: nocover diff --git a/prance/util/exceptions.py b/prance/util/exceptions.py index 2f01ae2..d999745 100644 --- a/prance/util/exceptions.py +++ b/prance/util/exceptions.py @@ -1,4 +1,6 @@ """This submodule contains helpers for exception handling.""" +from typing import Optional +from typing import Type __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2018,2019 Jens Finkhaeuser" @@ -8,7 +10,11 @@ # Raise the given exception class from the caught exception, preserving # stack trace and message as much as possible. -def raise_from(klass, from_value, extra_message=None): +def raise_from( + klass: type[BaseException], + from_value: BaseException | None, + extra_message: str | None = None, +) -> None: try: if from_value is None: if extra_message is not None: @@ -23,4 +29,4 @@ def raise_from(klass, from_value, extra_message=None): args.append(extra_message) raise klass(*args) from from_value finally: - klass = None + klass = None # type: ignore[assignment] diff --git a/prance/util/formats.py b/prance/util/formats.py index 4817a3d..1c40bd3 100644 --- a/prance/util/formats.py +++ b/prance/util/formats.py @@ -1,4 +1,10 @@ """This submodule contains file format related utility code for Prance.""" +from collections.abc import Callable +from typing import Dict +from typing import Optional +from typing import Tuple + +from prance.util.path import JsonValue __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser" @@ -10,7 +16,9 @@ class ParseError(ValueError): pass # pragma: nocover -def __format_preferences(filename, content_type): # noqa: N802 +def __format_preferences( + filename: str | None, content_type: str | None +) -> tuple[str, ...]: # noqa: N802 """ Detect the format based on file name and content type. @@ -26,7 +34,7 @@ def __format_preferences(filename, content_type): # noqa: N802 # 4) If both are present, prefer the content type. # 5) use a heuristic either way to catch bad content types, file names, # etc. The selection process above is just the most likely match! - best = None + best: str | None = None if filename and not content_type: from os.path import splitext @@ -61,27 +69,27 @@ def __format_preferences(filename, content_type): # noqa: N802 # Basic parse functions -def __parse_yaml(spec_str): # noqa: N802 - from ruamel.yaml import YAML, parser +def __parse_yaml(spec_str: str) -> JsonValue: # noqa: N802 + from ruamel.yaml import YAML, parser # type: ignore[import-not-found] try: yaml = YAML(typ="safe") - return yaml.load(str(spec_str)) + return yaml.load(str(spec_str)) # type: ignore[no-any-return] except parser.ParserError as err: raise ParseError(str(err)) -def __parse_json(spec_str): # noqa: N802 +def __parse_json(spec_str: str) -> JsonValue: # noqa: N802 import json try: - return json.loads(str(spec_str)) + return json.loads(str(spec_str)) # type: ignore[no-any-return] except ValueError as err: raise ParseError(str(err)) # Basic serialization functions -def __serialize_yaml(specs): # noqa: N802 +def __serialize_yaml(specs: JsonValue) -> str: # noqa: N802 import io from ruamel.yaml import YAML @@ -91,7 +99,7 @@ def __serialize_yaml(specs): # noqa: N802 return buf.getvalue().decode("UTF-8") -def __serialize_json(specs): # noqa: N802 +def __serialize_json(specs: JsonValue) -> str: # noqa: N802 # The default encoding is utf-8, no need to specify it. But we need to switch # off ensure_ascii, otherwise we do not get a unicode string back. import json @@ -102,29 +110,29 @@ def __serialize_json(specs): # noqa: N802 # Map file name extensions to parse/serialize functions -__EXT_TO_FORMAT = { +__EXT_TO_FORMAT: dict[tuple[str, ...], str] = { (".yaml", ".yml"): "YAML", (".json", ".js"): "JSON", } -__MIME_TO_FORMAT = { +__MIME_TO_FORMAT: dict[tuple[str, ...], str] = { ("application/json", "application/javascript"): "JSON", ("application/yaml", "text/yaml"): "YAML", } -__FORMAT_TO_PARSER = { +__FORMAT_TO_PARSER: dict[str, Callable[[str], JsonValue]] = { "YAML": __parse_yaml, "JSON": __parse_json, } -__FORMAT_TO_SERIALIZER = { +__FORMAT_TO_SERIALIZER: dict[str, Callable[[JsonValue], str]] = { "YAML": __serialize_yaml, "JSON": __serialize_json, } -def format_info(format_name): +def format_info(format_name: str) -> tuple[str | None, str | None]: """ Return content type and extension for a supported format. @@ -137,12 +145,12 @@ def format_info(format_name): """ format_name = format_name.upper() - content_type = None + content_type: str | None = None for content_types, name in __MIME_TO_FORMAT.items(): if name == format_name: content_type = content_types[0] - extension = None + extension: str | None = None for extensions, name in __EXT_TO_FORMAT.items(): if name == format_name: extension = extensions[0] @@ -150,7 +158,9 @@ def format_info(format_name): return content_type, extension -def parse_spec_details(spec_str, filename=None, **kwargs): +def parse_spec_details( + spec_str: str, filename: str | None = None, **kwargs: str | None +) -> tuple[JsonValue, str | None, str | None]: """ Return a parsed dict of the given spec string. @@ -169,8 +179,8 @@ def parse_spec_details(spec_str, filename=None, **kwargs): :raises ParseError: when parsing fails. """ # Fetch optional content type & determine formats - content_type = kwargs.get("content_type", None) - formats = __format_preferences(filename, content_type) + content_type_str: str | None = kwargs.get("content_type", None) + formats = __format_preferences(filename, content_type_str) # Try parsing each format in order for f in formats: @@ -186,7 +196,9 @@ def parse_spec_details(spec_str, filename=None, **kwargs): raise ParseError("Could not detect format of spec string!") -def parse_spec(spec_str, filename=None, **kwargs): +def parse_spec( + spec_str: str, filename: str | None = None, **kwargs: str | None +) -> JsonValue: """ Return a parsed dict of the given spec string. @@ -205,7 +217,9 @@ def parse_spec(spec_str, filename=None, **kwargs): return result -def serialize_spec(specs, filename=None, **kwargs): +def serialize_spec( + specs: JsonValue, filename: str | None = None, **kwargs: str | None +) -> str: """ Return a serialized version of the given spec. @@ -221,8 +235,8 @@ def serialize_spec(specs, filename=None, **kwargs): :rtype: str """ # Fetch optional content type & determine formats - content_type = kwargs.get("content_type", None) - formats = __format_preferences(filename, content_type) + content_type_str: str | None = kwargs.get("content_type", None) + formats = __format_preferences(filename, content_type_str) # Instead of trying to parse various formats, we only serialize to the first # one in the list - nothing else makes much sense. diff --git a/prance/util/fs.py b/prance/util/fs.py index 3886616..a12b18b 100644 --- a/prance/util/fs.py +++ b/prance/util/fs.py @@ -1,4 +1,5 @@ """This submodule contains file system utilities for Prance.""" +from typing import Optional __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2019 Jens Finkhaeuser" @@ -35,7 +36,7 @@ """ -def is_pathname_valid(pathname): +def is_pathname_valid(pathname: str) -> bool: """ Test whether a path name is valid. @@ -121,7 +122,7 @@ def is_pathname_valid(pathname): # Did we mention this should be shipped with Python already? -def from_posix(fname): +def from_posix(fname: str) -> str: """ Convert a path from posix-like, to the platform format. @@ -138,7 +139,7 @@ def from_posix(fname): return fname -def to_posix(fname): +def to_posix(fname: str) -> str: """ Convert a path to posix-like format. @@ -157,7 +158,7 @@ def to_posix(fname): return fname -def abspath(filename, relative_to=None): +def abspath(filename: str, relative_to: str | None = None) -> str: """ Return the absolute path of a file relative to a reference file. @@ -185,7 +186,7 @@ def abspath(filename, relative_to=None): return to_posix(fname) -def canonical_filename(filename): +def canonical_filename(filename: str) -> str: """ Return the canonical version of a file name. @@ -213,7 +214,7 @@ def canonical_filename(filename): return path -def detect_encoding(filename, default_to_utf8=True, **kwargs): +def detect_encoding(filename: str, default_to_utf8: bool = True, **kwargs: bool) -> str: """ Detect the named file's character encoding. @@ -254,13 +255,13 @@ def detect_encoding(filename, default_to_utf8=True, **kwargs): try: # First try ICU. ICU will report ASCII in the first 32 Bytes as # ISO-8859-1, which isn't exactly wrong, but maybe optimistic. - import icu + import icu # type: ignore[import-not-found] encoding = icu.CharsetDetector(raw).detect().getName().lower() except ImportError: # pragma: nocover # If that doesn't work, try chardet - it's not got native components, # which is a bonus in some environments, but it's not as precise. - import chardet + import chardet # type: ignore[import-not-found] encoding = chardet.detect(raw)["encoding"].lower() @@ -287,7 +288,7 @@ def detect_encoding(filename, default_to_utf8=True, **kwargs): return encoding -def read_file(filename, encoding=None): +def read_file(filename: str, encoding: str | None = None) -> str: """ Read and decode a file, taking BOMs into account. @@ -307,7 +308,7 @@ def read_file(filename, encoding=None): return handle.read() -def write_file(filename, contents, encoding=None): +def write_file(filename: str, contents: str, encoding: str | None = None) -> None: """ Write a file with the given encoding. diff --git a/prance/util/iterators.py b/prance/util/iterators.py index 8c737c0..2cb9b47 100644 --- a/prance/util/iterators.py +++ b/prance/util/iterators.py @@ -1,12 +1,31 @@ """This submodule contains specialty iterators over specs.""" +from collections.abc import Iterator +from collections.abc import Mapping +from collections.abc import Sequence +from typing import Tuple +from typing import Union __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2018 Jens Finkhaeuser" __license__ = "MIT" __all__ = () +# Type alias for JSON-like values (recursive structure) +JsonValue = Union[ + Mapping[Union[str, int], "JsonValue"], # Mappings can have str or int keys + Sequence["JsonValue"], + str, + int, + float, + bool, + None, +] +PathElement = Union[str, int] -def item_iterator(value, path=()): + +def item_iterator( + value: JsonValue, path: tuple[PathElement, ...] = () +) -> Iterator[tuple[tuple[PathElement, ...], JsonValue]]: """ Return item iterator over the a nested dict- or list-like object. @@ -44,8 +63,6 @@ def item_iterator(value, path=()): # Yield the top-level object, always yield path, value - from collections.abc import Mapping, Sequence - # For dict and list like objects, we also need to yield each item # recursively. if isinstance(value, Mapping): @@ -56,7 +73,9 @@ def item_iterator(value, path=()): yield from item_iterator(item, path + (idx,)) -def reference_iterator(specs, path=()): +def reference_iterator( + specs: JsonValue, path: tuple[PathElement, ...] = () +) -> Iterator[tuple[PathElement, JsonValue, tuple[PathElement, ...]]]: """ Iterate through the given specs, returning only references. diff --git a/prance/util/path.py b/prance/util/path.py index 24e70e4..7e86b09 100644 --- a/prance/util/path.py +++ b/prance/util/path.py @@ -1,25 +1,51 @@ """This module contains code for accessing values in nested data structures.""" +from collections.abc import Mapping +from collections.abc import MutableMapping +from collections.abc import MutableSequence +from collections.abc import Sequence +from collections.abc import Sequence as AbcSequence +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2018 Jens Finkhaeuser" __license__ = "MIT" __all__ = () +# Type aliases +PathElement = Union[str, int] +JsonValue = Union[ + Mapping[Union[str, int], "JsonValue"], # Mappings can have str or int keys + Sequence["JsonValue"], + str, + int, + float, + bool, + None, +] -def _json_ref_escape(path): + +def _json_ref_escape(path: PathElement) -> str: """JSON-reference escape object path.""" - path = str(path) # Could be an int, etc. - path = path.replace("~", "~0") - path = path.replace("/", "~1") - return path + path_str = str(path) # Could be an int, etc. + path_str = path_str.replace("~", "~0") + path_str = path_str.replace("/", "~1") + return path_str -def _str_path(path): +def _str_path(path: Sequence[PathElement]) -> str: """Stringify object path.""" return "/" + "/".join([_json_ref_escape(p) for p in path]) -def path_get(obj, path, defaultvalue=None, path_of_obj=()): +def path_get( + obj: JsonValue, + path: Sequence[PathElement] | None, + defaultvalue: JsonValue = None, + path_of_obj: tuple[PathElement, ...] = (), +) -> JsonValue: """ Retrieve the value from obj indicated by path. @@ -35,12 +61,10 @@ def path_get(obj, path, defaultvalue=None, path_of_obj=()): :param mixed defaultvalue: If the value at the path does not exist and this parameter is not None, it is returned. Otherwise an error is raised. """ - from collections.abc import Mapping, Sequence - # For error reporting. path_of_obj_str = _str_path(path_of_obj) - if path is not None and not isinstance(path, Sequence): + if path is not None and not isinstance(path, AbcSequence): raise TypeError(f"Path is a {type(path)}, but must be None or a Collection!") if isinstance(obj, Mapping): @@ -58,7 +82,7 @@ def path_get(obj, path, defaultvalue=None, path_of_obj=()): obj[path[0]], path[1:], defaultvalue, path_of_obj=path_of_obj + (path[0],) ) - elif isinstance(obj, Sequence): + elif isinstance(obj, AbcSequence): if path is None or len(path) < 1: return obj or defaultvalue @@ -90,7 +114,9 @@ def path_get(obj, path, defaultvalue=None, path_of_obj=()): return obj or defaultvalue -def path_set(obj, path, value, **options): +def path_set( + obj: JsonValue, path: Sequence[PathElement], value: JsonValue, **options: bool +) -> JsonValue: """ Set the value in obj indicated by path. @@ -108,7 +134,9 @@ def path_set(obj, path, value, **options): # Retrieve options create = options.get("create", False) - def fill_sequence(seq, index, value_index_type): + def fill_sequence( + seq: MutableSequence[JsonValue], index: int, value_index_type: type[int] | None + ) -> None: """ Fill the sequence seq with elements until index can be accessed. @@ -130,7 +158,7 @@ def fill_sequence(seq, index, value_index_type): else: seq.append({}) - def safe_idx(seq, index): + def safe_idx(seq: Sequence[PathElement], index: int) -> type[int] | None: """ Safely index a sequence. @@ -138,7 +166,7 @@ def safe_idx(seq, index): raising IndexError. """ try: - return type(seq[index]) + return type(seq[index]) # type: ignore[return-value] except IndexError: return None @@ -146,9 +174,7 @@ def safe_idx(seq, index): # print('path', path) # print('value', value) - from collections.abc import Sequence, MutableSequence, Mapping, MutableMapping - - if path is not None and not isinstance(path, Sequence): + if path is not None and not isinstance(path, AbcSequence): raise TypeError(f"Path is a {type(path)}, but must be None or a Collection!") if len(path) < 1: @@ -177,7 +203,7 @@ def safe_idx(seq, index): return obj - elif isinstance(obj, Sequence): + elif isinstance(obj, AbcSequence): idx = path[0] # If we don't have a mutable sequence, we should raise a TypeError diff --git a/prance/util/resolver.py b/prance/util/resolver.py index 92bb3ca..8d2aa5b 100644 --- a/prance/util/resolver.py +++ b/prance/util/resolver.py @@ -1,12 +1,25 @@ """This submodule contains a JSON inlining reference resolver.""" +from collections.abc import Callable +from collections.abc import Iterator +from collections.abc import MutableMapping +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from urllib.parse import ParseResult + +import prance.util.url as _url +from prance.util.path import JsonValue +from prance.util.path import PathElement __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2018 Jens Finkhaeuser" __license__ = "MIT" __all__ = () -import prance.util.url as _url - #: Resolve internal references RESOLVE_INTERNAL = 2**1 #: Resolve references to HTTP external files. @@ -23,23 +36,29 @@ RESOLVE_ALL = RESOLVE_INTERNAL | RESOLVE_HTTP | RESOLVE_FILES -def default_reclimit_handler(limit, parsed_url, recursions=()): +def default_reclimit_handler( + limit: int, + parsed_url: ParseResult, + recursions: tuple[tuple[str, tuple[PathElement, ...]], ...] = (), +) -> None: """Raise prance.util.url.ResolutionError.""" - path = [] + path: list[str] = [] for rc in recursions: - path.append("{}#/{}".format(rc[0], "/".join(rc[1]))) - path = "\n".join(path) + path.append("{}#/{}".format(rc[0], "/".join(str(p) for p in rc[1]))) + path_str = "\n".join(path) raise _url.ResolutionError( "Recursion reached limit of %d trying to " - 'resolve "%s"!\n%s' % (limit, parsed_url.geturl(), path) + 'resolve "%s"!\n%s' % (limit, parsed_url.geturl(), path_str) ) class RefResolver: """Resolve JSON pointers/references in a spec by inlining.""" - def __init__(self, specs, url=None, **options): + def __init__( + self, specs: JsonValue, url: str | ParseResult | None = None, **options: Any + ) -> None: """ Construct a JSON reference resolver. @@ -80,19 +99,23 @@ def __init__(self, specs, url=None, **options): """ import copy - self.specs = copy.deepcopy(specs) - self.url = url + self.specs: JsonValue = copy.deepcopy(specs) + self.url: str | ParseResult | None = url - self.__reclimit = options.get("recursion_limit", 1) - self.__reclimit_handler = options.get( - "recursion_limit_handler", default_reclimit_handler + self.__reclimit: int = options.get("recursion_limit", 1) + self.__reclimit_handler: Callable[ + [int, ParseResult, tuple[tuple[str, tuple[PathElement, ...]], ...]], Any + ] = options.get("recursion_limit_handler", default_reclimit_handler) + self.__reference_cache: dict[str | tuple[str, bool], JsonValue] = options.get( + "reference_cache", {} ) - self.__reference_cache = options.get("reference_cache", {}) - self.__resolve_types = options.get("resolve_types", RESOLVE_ALL) - self.__resolve_method = options.get("resolve_method", TRANSLATE_DEFAULT) - self.__encoding = options.get("encoding", None) - self.__strict = options.get("strict", True) + self.__resolve_types: int = options.get("resolve_types", RESOLVE_ALL) + self.__resolve_method: int = options.get("resolve_method", TRANSLATE_DEFAULT) + self.__encoding: str | None = options.get("encoding", None) + self.__strict: bool = options.get("strict", True) + self.parsed_url: ParseResult | None + self._url_key: tuple[str, bool] | None if self.url: self.parsed_url = _url.absurl(self.url) self._url_key = (_url.urlresource(self.parsed_url), self.__strict) @@ -105,23 +128,34 @@ def __init__(self, specs, url=None, **options): else: self.parsed_url = self._url_key = None - self.__soft_dereference_objs = {} + self.__soft_dereference_objs: dict[str, JsonValue] = {} - def resolve_references(self): + def resolve_references(self) -> None: """Resolve JSON pointers/references in the spec.""" self.specs = self._resolve_partial(self.parsed_url, self.specs, ()) # If there are any objects collected when using TRANSLATE_EXTERNAL, add # them to components/schemas if self.__soft_dereference_objs: - if "components" not in self.specs: - self.specs["components"] = {} - if "schemas" not in self.specs["components"]: - self.specs["components"].update({"schemas": {}}) - - self.specs["components"]["schemas"].update(self.__soft_dereference_objs) - - def _dereferencing_iterator(self, base_url, partial, path, recursions): + # Type narrow specs to MutableMapping for safe indexing + if isinstance(self.specs, MutableMapping): + if "components" not in self.specs: + self.specs["components"] = {} + components = self.specs["components"] + if isinstance(components, MutableMapping): + if "schemas" not in components: + components.update({"schemas": {}}) + schemas = components["schemas"] + if isinstance(schemas, MutableMapping): + schemas.update(self.__soft_dereference_objs) + + def _dereferencing_iterator( + self, + base_url: ParseResult | None, + partial: JsonValue, + path: tuple[PathElement, ...], + recursions: tuple[tuple[str, tuple[PathElement, ...]], ...], + ) -> Iterator[tuple[tuple[PathElement, ...], JsonValue]]: """ Iterate over a partial spec, dereferencing all references within. @@ -135,11 +169,15 @@ def _dereferencing_iterator(self, base_url, partial, path, recursions): from .iterators import reference_iterator for _, refstring, item_path in reference_iterator(partial): + # Type narrow refstring to str for split_url_reference + if not isinstance(refstring, str): + continue + # Split the reference string into parsed URL and object path ref_url, obj_path = _url.split_url_reference(base_url, refstring) translate = (self.__resolve_method == TRANSLATE_EXTERNAL) and ( - self.parsed_url.path != ref_url.path + self.parsed_url is not None and self.parsed_url.path != ref_url.path ) if self._skip_reference(base_url, ref_url): @@ -175,23 +213,29 @@ def _dereferencing_iterator(self, base_url, partial, path, recursions): else: yield full_path, ref_value - def _collect_soft_refs(self, ref_url, item_path, value): + def _collect_soft_refs( + self, ref_url: ParseResult, item_path: list[PathElement], value: JsonValue + ) -> str: """ Return a portion of the dereferenced url for TRANSLATE_EXTERNAL mode. format - ref-url_obj-path """ - dref_url = ref_url.path.split("/")[-1] + "_" + "_".join(item_path[1:]) + dref_url = ( + ref_url.path.split("/")[-1] + "_" + "_".join(str(p) for p in item_path[1:]) + ) self.__soft_dereference_objs[dref_url] = value return dref_url - def _skip_reference(self, base_url, ref_url): + def _skip_reference( + self, base_url: ParseResult | None, ref_url: ParseResult + ) -> bool: """Return whether the URL should not be dereferenced.""" if ref_url.scheme.startswith("http"): return (self.__resolve_types & RESOLVE_HTTP) == 0 elif ref_url.scheme == "file" or ref_url.scheme == "python": # Internal references - if base_url.path == ref_url.path: + if base_url is not None and base_url.path == ref_url.path: return (self.__resolve_types & RESOLVE_INTERNAL) == 0 # Local files return (self.__resolve_types & RESOLVE_FILES) == 0 @@ -204,7 +248,12 @@ def _skip_reference(self, base_url, ref_url): ) ) - def _dereference(self, ref_url, obj_path, recursions): + def _dereference( + self, + ref_url: ParseResult, + obj_path: list[PathElement], + recursions: tuple[tuple[str, tuple[PathElement, ...]], ...], + ) -> JsonValue: """ Dereference the URL and object path. @@ -246,7 +295,12 @@ def _dereference(self, ref_url, obj_path, recursions): # That's it! return value - def _resolve_partial(self, base_url, partial, recursions): + def _resolve_partial( + self, + base_url: ParseResult | None, + partial: JsonValue, + recursions: tuple[tuple[str, tuple[PathElement, ...]], ...], + ) -> JsonValue: """ Resolve a (partial) spec's references. diff --git a/prance/util/translator.py b/prance/util/translator.py index 1dc3223..e195b26 100644 --- a/prance/util/translator.py +++ b/prance/util/translator.py @@ -1,23 +1,33 @@ """This submodule contains a JSON reference translator.""" +from collections.abc import Iterator +from collections.abc import MutableMapping +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from urllib.parse import ParseResult + +import prance.util.url as _url +from prance.util.path import JsonValue +from prance.util.path import PathElement __author__ = "Štěpán Tomsa" __copyright__ = "Copyright © 2021 Štěpán Tomsa" __license__ = "MIT" __all__ = () -import prance.util.url as _url - -def _reference_key(ref_url, item_path): +def _reference_key(ref_url: ParseResult, item_path: list[PathElement]) -> str: """ Return a portion of the dereferenced URL. format - ref-url_obj-path """ - return ref_url.path.split("/")[-1] + "_" + "_".join(item_path[1:]) + return ref_url.path.split("/")[-1] + "_" + "_".join(str(p) for p in item_path[1:]) -def _local_ref(path): +def _local_ref(path: list[str]) -> dict[str, str]: url = "#/" + "/".join(path) return {"$ref": url} @@ -32,7 +42,7 @@ class _RefTranslator: object locations. """ - def __init__(self, specs, url): + def __init__(self, specs: JsonValue, url: str | None) -> None: """ Construct a JSON reference translator. @@ -47,15 +57,16 @@ def __init__(self, specs, url): """ import copy - self.specs = copy.deepcopy(specs) + self.specs: JsonValue = copy.deepcopy(specs) - self.__strict = True - self.__reference_cache = {} - self.__collected_references = {} + self.__strict: bool = True + self.__reference_cache: dict[tuple[str, bool], JsonValue] = {} + self.__collected_references: dict[str, JsonValue | None] = {} + self.url: ParseResult | None if url: self.url = _url.absurl(url) - url_key = (_url.urlresource(self.url), self.__strict) + url_key: tuple[str, bool] = (_url.urlresource(self.url), self.__strict) # If we have a url, we want to add ourselves to the reference cache # - that creates a reference loop, but prevents child resolvers from @@ -64,7 +75,7 @@ def __init__(self, specs, url): else: self.url = None - def translate_references(self): + def translate_references(self) -> None: """ Iterate over the specification document, performing the translation. @@ -72,18 +83,29 @@ def translate_references(self): external files to the /components/schemas object in the root document and translating the references to the new location. """ + # url must be a ParseResult for _translate_partial + if self.url is None: + return + self.specs = self._translate_partial(self.url, self.specs) # Add collected references to the root document. if self.__collected_references: - if "components" not in self.specs: - self.specs["components"] = {} - if "schemas" not in self.specs["components"]: - self.specs["components"].update({"schemas": {}}) - - self.specs["components"]["schemas"].update(self.__collected_references) - - def _dereference(self, ref_url, obj_path): + # Type narrow specs to MutableMapping for safe indexing + if isinstance(self.specs, MutableMapping): + if "components" not in self.specs: + self.specs["components"] = {} + components = self.specs["components"] + if isinstance(components, MutableMapping): + if "schemas" not in components: + components.update({"schemas": {}}) + schemas = components["schemas"] + if isinstance(schemas, MutableMapping): + schemas.update(self.__collected_references) + + def _dereference( + self, ref_url: ParseResult, obj_path: list[PathElement] + ) -> JsonValue: """ Dereference the URL and object path. @@ -97,7 +119,7 @@ def _dereference(self, ref_url, obj_path): """ # In order to start dereferencing anything in the referenced URL, we have # to read and parse it, of course. - contents = _url.fetch_url(ref_url, self.__reference_cache, strict=self.__strict) + contents = _url.fetch_url(ref_url, self.__reference_cache, strict=self.__strict) # type: ignore[arg-type] # In this inner parser's specification, we can now look for the referenced # object. @@ -123,7 +145,9 @@ def _dereference(self, ref_url, obj_path): # That's it! return value - def _translate_partial(self, base_url, partial): + def _translate_partial( + self, base_url: ParseResult, partial: JsonValue + ) -> JsonValue: changes = dict(tuple(self._translating_iterator(base_url, partial, ()))) paths = sorted(changes.keys(), key=len) @@ -131,7 +155,7 @@ def _translate_partial(self, base_url, partial): from prance.util.path import path_set for path in paths: - value = changes[path] + value: JsonValue = changes[path] # type: ignore[assignment] if len(path) == 0: partial = value else: @@ -139,14 +163,20 @@ def _translate_partial(self, base_url, partial): return partial - def _translating_iterator(self, base_url, partial, path): + def _translating_iterator( + self, base_url: ParseResult, partial: JsonValue, path: tuple[PathElement, ...] + ) -> Iterator[tuple[tuple[PathElement, ...], dict[str, str]]]: from prance.util.iterators import reference_iterator for _, ref_string, item_path in reference_iterator(partial): + # Type narrow ref_string to str for split_url_reference + if not isinstance(ref_string, str): + continue + ref_url, obj_path = _url.split_url_reference(base_url, ref_string) full_path = path + item_path - if ref_url.path == self.url.path: + if self.url is None or ref_url.path == self.url.path: # Reference to the root document. ref_path = obj_path else: @@ -158,5 +188,7 @@ def _translating_iterator(self, base_url, partial, path): self.__collected_references[ref_key] = ref_value ref_path = ["components", "schemas", ref_key] - ref_obj = _local_ref(ref_path) + # Convert ref_path to List[str] for _local_ref + ref_path_str: list[str] = [str(p) for p in ref_path] + ref_obj = _local_ref(ref_path_str) yield full_path, ref_obj diff --git a/prance/util/url.py b/prance/util/url.py index a9c18e3..77cde30 100644 --- a/prance/util/url.py +++ b/prance/util/url.py @@ -1,4 +1,16 @@ """This submodule contains code for fetching/parsing URLs.""" +from collections.abc import Mapping +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from urllib import parse +from urllib.parse import ParseResult + +from prance.util.path import JsonValue +from prance.util.path import PathElement __author__ = "Jens Finkhaeuser" __copyright__ = "Copyright (c) 2016-2018 Jens Finkhaeuser" @@ -6,14 +18,11 @@ __all__ = () -from urllib import parse - - class ResolutionError(LookupError): pass -def urlresource(url): +def urlresource(url: ParseResult) -> str: """ Return the resource part of a parsed URL. @@ -24,11 +33,15 @@ def urlresource(url): :return: The resource part of the URL :rtype: str """ - res_list = list(url)[0:3] + [None, None, None] - return parse.ParseResult(*res_list).geturl() + res_list: list[str | None] = list(url)[0:3] + [None, None, None] + return parse.ParseResult( + *cast(tuple[str, str, str, str, str, str], res_list) + ).geturl() -def absurl(url, relative_to=None): +def absurl( + url: str | ParseResult, relative_to: str | ParseResult | None = None +) -> ParseResult: """ Turn relative file URLs into absolute file URLs. @@ -46,8 +59,10 @@ def absurl(url, relative_to=None): :rtype: tuple """ # Parse input URL, if necessary - parsed = url - if not isinstance(parsed, tuple): + parsed: ParseResult + if isinstance(url, tuple): + parsed = url + else: from .fs import is_pathname_valid if is_pathname_valid(url): @@ -66,15 +81,18 @@ def absurl(url, relative_to=None): return parsed # Parse up the reference URL - reference = relative_to - if reference and not isinstance(reference, tuple): - from .fs import is_pathname_valid + reference: ParseResult | None = None + if relative_to: + if isinstance(relative_to, tuple): + reference = relative_to + else: + from .fs import is_pathname_valid - if is_pathname_valid(reference): - from . import fs + if is_pathname_valid(relative_to): + from . import fs - reference = fs.to_posix(reference) - reference = parse.urlparse(reference) + relative_to = fs.to_posix(relative_to) + reference = parse.urlparse(relative_to) # If the input URL has no path, we assume only its fragment matters. # That is, we'll have to set the fragment of the reference URL to that @@ -82,7 +100,7 @@ def absurl(url, relative_to=None): import os.path from .fs import from_posix, abspath - result_list = None + result_list: list[str] | None = None if not parsed.path: if not reference or not reference.path: raise ResolutionError( @@ -116,7 +134,9 @@ def absurl(url, relative_to=None): return result -def split_url_reference(base_url, reference): +def split_url_reference( + base_url: ParseResult | None, reference: str +) -> tuple[ParseResult, list[PathElement]]: """ Return a normalized, parsed URL and object path. @@ -138,17 +158,21 @@ def split_url_reference(base_url, reference): obj_path = obj_path[1:] # Normalize the object path by substituting ~1 and ~0 respectively. - def _normalize(path): + def _normalize(path: str) -> str: path = path.replace("~1", "/") path = path.replace("~0", "~") return path - obj_path = [_normalize(p) for p in obj_path] + obj_path_normalized: list[PathElement] = [_normalize(p) for p in obj_path] - return parsed_url, obj_path + return parsed_url, obj_path_normalized -def fetch_url_text(url, cache={}, encoding=None): +def fetch_url_text( + url: ParseResult, + cache: dict[str, tuple[str, str | None]] | None = None, + encoding: str | None = None, +) -> tuple[str, str | None]: """ Fetch the URL. @@ -167,6 +191,9 @@ def fetch_url_text(url, cache={}, encoding=None): :return: The resource text of the URL, and the content type. :rtype: tuple """ + if cache is None: + cache = {} + url_key = "text_" + urlresource(url) entry = cache.get(url_key, None) if entry is not None: @@ -174,8 +201,8 @@ def fetch_url_text(url, cache={}, encoding=None): # Fetch contents according to scheme. We assume requests can handle all the # non-file schemes, or throw otherwise. - content = None - content_type = None + content: str + content_type: str | None = None if url.scheme in (None, "", "file"): from .fs import read_file, from_posix @@ -194,13 +221,13 @@ def fetch_url_text(url, cache={}, encoding=None): from importlib.resources import files - path = files(package).joinpath(path) + path_traversable = files(package).joinpath(path) from .fs import read_file, from_posix - content = read_file(from_posix(path), encoding) + content = read_file(from_posix(str(path_traversable)), encoding) else: - import requests + import requests # type: ignore[import-untyped] response = requests.get(url.geturl()) if not response.ok: # pragma: nocover @@ -215,7 +242,12 @@ def fetch_url_text(url, cache={}, encoding=None): return content, content_type -def fetch_url(url, cache={}, encoding=None, strict=True): +def fetch_url( + url: ParseResult, + cache: dict[str | tuple[str, bool], JsonValue] | None = None, + encoding: str | None = None, + strict: bool = True, +) -> JsonValue: """ Fetch the URL and parse the contents. @@ -231,13 +263,23 @@ def fetch_url(url, cache={}, encoding=None, strict=True): :rtype: dict """ # Return from cache, if parsed result is already present. - url_key = (urlresource(url), strict) - entry = cache.get(url_key, None) + if cache is None: + cache = {} + + url_key_tuple: tuple[str, bool] = (urlresource(url), strict) + entry = cache.get(url_key_tuple, None) if entry is not None: - return entry.copy() + if isinstance(entry, Mapping): + return entry.copy() # type: ignore[no-any-return, attr-defined] + return entry # Fetch URL text - content, content_type = fetch_url_text(url, cache, encoding=encoding) + text_cache: dict[str, tuple[str, str | None]] = {} + for key, value in cache.items(): + if isinstance(key, str) and isinstance(value, tuple): + text_cache[key] = value + content, content_type = fetch_url_text(url, text_cache, encoding=encoding) + cache.update(text_cache) # Parse the result from .formats import parse_spec @@ -246,10 +288,14 @@ def fetch_url(url, cache={}, encoding=None, strict=True): # Perform some sanitization in lenient mode. if not strict: + from collections.abc import MutableMapping from . import stringify_keys - result = stringify_keys(result) + if isinstance(result, MutableMapping): + result = stringify_keys(result) # Cache and return result - cache[url_key] = result - return result.copy() + cache[url_key_tuple] = result + if isinstance(result, Mapping): + return result.copy() # type: ignore[no-any-return, attr-defined] + return result diff --git a/pyproject.toml b/pyproject.toml index 9f1cc0d..86f75ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,3 +83,23 @@ package = "prance" package_dir = "." filename = "CHANGES.rst" directory = "changelog.d" + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +strict_equality = true +show_error_codes = true + +[[tool.mypy.overrides]] +module = "prance.cli" +disallow_untyped_decorators = false diff --git a/tests/test_util_fs.py b/tests/test_util_fs.py index 07b05f8..8326877 100644 --- a/tests/test_util_fs.py +++ b/tests/test_util_fs.py @@ -209,5 +209,5 @@ def test_valid_pathname(): assert True == is_pathname_valid("foo") assert False == is_pathname_valid(123) - # Can't accept too long components - assert False == is_pathname_valid("a" * 256) + # Can't accept too long components (use 300 to ensure it fails on all platforms) + assert False == is_pathname_valid("a" * 300)