Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 65 additions & 40 deletions prance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@
Included is a BaseParser that reads and validates swagger specs, and a
ResolvingParser that additionally resolves any $ref references.
"""
import sys
from typing import Any
from typing import Dict
from typing import Optional
from typing import Union
from urllib.parse import ParseResult

from packaging.version import Version # type: ignore[import-not-found]

from prance.util.path import JsonValue

__author__ = "Jens Finkhaeuser"
__copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser"
__license__ = "MIT"
__all__ = ("util", "mixins", "cli", "convert")
import sys

from packaging.version import Version

try:
from prance._version import version as __version__
Expand Down Expand Up @@ -55,7 +62,13 @@ class BaseParser(mixins.YAMLMixin, mixins.JSONMixin):
SPEC_VERSION_2_PREFIX = "Swagger/OpenAPI"
SPEC_VERSION_3_PREFIX = "OpenAPI"

def __init__(self, url=None, spec_string=None, lazy=False, **kwargs):
def __init__(
self,
url: str | None = None,
spec_string: str | None = None,
lazy: bool = False,
**kwargs: Any,
) -> None:
"""
Load, parse and validate specs.

Expand All @@ -82,32 +95,34 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs):
)

# Keep the parameters around for later use
self.url = None
self.url: ParseResult
if url:
from .util.url import absurl
from .util.fs import abspath
import os

self.url = absurl(url, abspath(os.getcwd()))
else:
self.url = _PLACEHOLDER_URL
from urllib.parse import urlparse

self.url = urlparse(_PLACEHOLDER_URL)

self._spec_string = spec_string
self._spec_string: str | None = spec_string

# Initialize variables we're filling later
self.specification = None
self.version = None
self.version_name = None
self.version_parsed = ()
self.valid = False
self.specification: JsonValue | None = None
self.version: str | None = None
self.version_name: str | None = None
self.version_parsed: tuple = ()
self.valid: bool = False

# Add kw args as options
self.options = kwargs
self.options: dict[str, Any] = kwargs

# Verify backend
from .util import default_validation_backend

self.backend = self.options.get("backend", default_validation_backend())
self.backend: str = self.options.get("backend", default_validation_backend())
if self.backend not in BaseParser.BACKENDS.keys():
raise ValueError(
f"Backend may only be one of {BaseParser.BACKENDS.keys()}!"
Expand All @@ -117,7 +132,7 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs):
if not lazy:
self.parse()

def parse(self): # noqa: F811
def parse(self) -> None: # noqa: F811
"""
When the BaseParser was lazily created, load and parse now.

Expand All @@ -128,7 +143,7 @@ def parse(self): # noqa: F811
strict = self.options.get("strict", True)

# If we have a file name, we need to read that in.
if self.url and self.url != _PLACEHOLDER_URL:
if self.url and self.url.geturl() != _PLACEHOLDER_URL:
from .util.url import fetch_url

encoding = self.options.get("encoding", None)
Expand All @@ -138,7 +153,7 @@ def parse(self): # noqa: F811
if self._spec_string:
from .util.formats import parse_spec

self.specification = parse_spec(self._spec_string, self.url)
self.specification = parse_spec(self._spec_string, self.url.path)

# If we have a parsed spec, convert it to JSON. Then we can validate
# the JSON. At this point, we *require* a parsed specification to exist,
Expand All @@ -147,7 +162,7 @@ def parse(self): # noqa: F811

self._validate()

def _validate(self):
def _validate(self) -> None:
# Ensure specification is a mapping
from collections.abc import Mapping

Expand All @@ -159,18 +174,22 @@ def _validate(self):

# Fetch the spec version. Note that this is the spec version the spec
# *claims* to be; we later set the one we actually could validate as.
spec_version = None
spec_version: str | None = None
if spec_version is None:
spec_version = self.specification.get("openapi", None)
version_val = self.specification.get("openapi", None)
if isinstance(version_val, str):
spec_version = version_val
if spec_version is None:
spec_version = self.specification.get("swagger", None)
version_val = self.specification.get("swagger", None)
if isinstance(version_val, str):
spec_version = version_val
if spec_version is None:
raise ValidationError(
"Could not determine specification schema " "version!"
)

# Try parsing the spec version, examine the first component.
import packaging.version
import packaging.version # type: ignore[import-not-found]

parsed = packaging.version.parse(spec_version)
if parsed.major not in versions:
Expand All @@ -187,7 +206,7 @@ def _validate(self):
validator(parsed)
self.valid = True

def __set_version(self, prefix, version: Version):
def __set_version(self, prefix: str, version: Version) -> None:
self.version_name = prefix
self.version_parsed = version.release

Expand All @@ -196,12 +215,12 @@ def __set_version(self, prefix, version: Version):
stringified = "%d.%d" % (version.major, version.minor)
self.version = f"{self.version_name} {stringified}"

def _validate_flex(self, spec_version: Version): # pragma: nocover
def _validate_flex(self, spec_version: Version) -> None: # pragma: nocover
# Set the version independently of whether validation succeeds
self.__set_version(BaseParser.SPEC_VERSION_2_PREFIX, spec_version)

from flex.exceptions import ValidationError as JSEValidationError
from flex.core import parse as validate
from flex.exceptions import ValidationError as JSEValidationError # type: ignore[import-not-found]
from flex.core import parse as validate # type: ignore[import-not-found]

try:
validate(self.specification)
Expand All @@ -212,12 +231,12 @@ def _validate_flex(self, spec_version: Version): # pragma: nocover

def _validate_swagger_spec_validator(
self, spec_version: Version
): # pragma: nocover
) -> None: # pragma: nocover
# Set the version independently of whether validation succeeds
self.__set_version(BaseParser.SPEC_VERSION_2_PREFIX, spec_version)

from swagger_spec_validator.common import SwaggerValidationError as SSVErr
from swagger_spec_validator.validator20 import validate_spec
from swagger_spec_validator.common import SwaggerValidationError as SSVErr # type: ignore[import-not-found]
from swagger_spec_validator.validator20 import validate_spec # type: ignore[import-not-found]

try:
validate_spec(self.specification)
Expand All @@ -228,10 +247,10 @@ def _validate_swagger_spec_validator(

def _validate_openapi_spec_validator(
self, spec_version: Version
): # pragma: nocover
from openapi_spec_validator import validate
from jsonschema.exceptions import ValidationError as JSEValidationError
from referencing.exceptions import Unresolvable
) -> None: # pragma: nocover
from openapi_spec_validator import validate # type: ignore[import-not-found]
from jsonschema.exceptions import ValidationError as JSEValidationError # type: ignore[import-untyped]
from referencing.exceptions import Unresolvable # type: ignore[import-not-found]

# Validate according to detected version. Unsupported versions are
# already caught outside of this function.
Expand All @@ -253,7 +272,7 @@ def _validate_openapi_spec_validator(
except Unresolvable as ref_unres:
raise_from(ValidationError, ref_unres)

def _strict_warning(self):
def _strict_warning(self) -> str:
"""Return a warning if strict mode is off."""
if self.options.get("strict", True):
return (
Expand All @@ -269,7 +288,13 @@ def _strict_warning(self):
class ResolvingParser(BaseParser):
"""The ResolvingParser extends BaseParser with resolving references by inlining."""

def __init__(self, url=None, spec_string=None, lazy=False, **kwargs):
def __init__(
self,
url: str | None = None,
spec_string: str | None = None,
lazy: bool = False,
**kwargs: Any,
) -> None:
"""
See :py:class:`BaseParser`.

Expand All @@ -280,11 +305,11 @@ def __init__(self, url=None, spec_string=None, lazy=False, **kwargs):
Additional parameters, see :py::class:`util.RefResolver`.
"""
# Create a reference cache
self.__reference_cache = {}
self.__reference_cache: dict[str | tuple, JsonValue] = {}

BaseParser.__init__(self, url=url, spec_string=spec_string, lazy=lazy, **kwargs)

def _validate(self):
def _validate(self) -> None:
# We have a problem with the BaseParser's validate function: the
# jsonschema implementation underlying it does not accept relative
# path references, but the Swagger specs allow them:
Expand All @@ -300,7 +325,7 @@ def _validate(self):
"resolve_method",
"strict",
)
forward_args = {
forward_args: dict[str, Any] = {
k: v for (k, v) in self.options.items() if k in forward_arg_names
}
resolver = RefResolver(
Expand All @@ -318,10 +343,10 @@ def _validate(self):

# Underscored to allow some time for the public API to be stabilized.
class _TranslatingParser(BaseParser):
def _validate(self):
def _validate(self) -> None:
from .util.translator import _RefTranslator

translator = _RefTranslator(self.specification, self.url)
translator = _RefTranslator(self.specification, self.url.geturl())
translator.translate_references()
self.specification = translator.specs

Expand Down
45 changes: 27 additions & 18 deletions prance/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""CLI for prance."""
from typing import Any
from typing import Optional
from typing import Tuple

import click # type: ignore[import-not-found]

import prance
from prance.util import default_validation_backend
from prance.util.path import JsonValue

__author__ = "Jens Finkhaeuser"
__copyright__ = "Copyright (c) 2016-2021 Jens Finkhaeuser"
__license__ = "MIT"
__all__ = ()


import click

import prance
from prance.util import default_validation_backend


def __write_to_file(filename, specs): # noqa: N802
def __write_to_file(filename: str, specs: JsonValue) -> None: # noqa: N802
"""
Write specs to the given filename.

Expand All @@ -24,7 +27,9 @@ def __write_to_file(filename, specs): # noqa: N802
fs.write_file(filename, contents)


def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802
def __parser_for_url(
url: str, resolve: bool, backend: str, strict: bool, encoding: str | None
) -> tuple[prance.BaseParser, str]: # noqa: N802
"""Return a parser instance for the URL and the given parameters."""
# Try the URL
formatted = click.format_filename(url)
Expand All @@ -39,7 +44,7 @@ def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802
url = fsurl

# Create parser to use
parser = None
parser: prance.BaseParser
if resolve:
click.echo(" -> Resolving external references.")
parser = prance.ResolvingParser(
Expand All @@ -56,7 +61,7 @@ def __parser_for_url(url, resolve, backend, strict, encoding): # noqa: N802
return parser, formatted


def __validate(parser, name): # noqa: N802
def __validate(parser: prance.BaseParser, name: str) -> None: # noqa: N802
"""Validate a spec using this parser."""
from prance.util.url import ResolutionError
from prance import ValidationError
Expand All @@ -76,14 +81,14 @@ def __validate(parser, name): # noqa: N802

@click.group()
@click.version_option(version=prance.__version__)
def cli():
def cli() -> None:
pass # pragma: no cover


class GroupWithCommandOptions(click.Group):
"""Allow application of options to group with multi command."""

def add_command(self, cmd, name=None):
def add_command(self, cmd: click.Command, name: str | None = None) -> None:
click.Group.add_command(self, cmd, name=name)

# add the group parameters to the command
Expand All @@ -94,8 +99,8 @@ def add_command(self, cmd, name=None):
cmd.invoke = self.build_command_invoke(cmd.invoke)
self.invoke_without_command = True

def build_command_invoke(self, original_invoke):
def command_invoke(ctx):
def build_command_invoke(self, original_invoke: Any) -> Any:
def command_invoke(ctx: click.Context) -> None:
"""Insert invocation of group function."""
# separate the group parameters
ctx.obj = dict(_params=dict())
Expand Down Expand Up @@ -145,7 +150,9 @@ def command_invoke(ctx):
"encoding for all files. Does not work on remote URLs.",
)
@click.pass_context
def backend_options(ctx, resolve, backend, strict, encoding):
def backend_options(
ctx: click.Context, resolve: bool, backend: str, strict: bool, encoding: str | None
) -> None:
ctx.obj["resolve"] = resolve
ctx.obj["backend"] = backend
ctx.obj["strict"] = strict
Expand All @@ -171,7 +178,9 @@ def backend_options(ctx, resolve, backend, strict, encoding):
nargs=-1,
)
@click.pass_context
def validate(ctx, output_file, urls):
def validate(
ctx: click.Context, output_file: str | None, urls: tuple[str, ...]
) -> None:
"""
Validate the given spec or specs.

Expand Down Expand Up @@ -226,7 +235,7 @@ def validate(ctx, output_file, urls):
required=False,
)
@click.pass_context
def compile(ctx, url_or_path, output_file):
def compile(ctx: click.Context, url_or_path: str, output_file: str | None) -> None:
"""
Compile the given spec, resolving references if required.

Expand Down Expand Up @@ -273,7 +282,7 @@ def compile(ctx, url_or_path, output_file):
nargs=1,
required=False,
)
def convert(url_or_path, output_file):
def convert(url_or_path: str, output_file: str | None) -> None:
"""
Convert the given spec to OpenAPI 3.x.y.

Expand Down
Loading