-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsource.py
More file actions
188 lines (148 loc) · 6.16 KB
/
source.py
File metadata and controls
188 lines (148 loc) · 6.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator

from ..errors.exceptions import SpecInvalidError

if TYPE_CHECKING:
    from fastapi import FastAPI
class OpenAPIInfo(BaseModel):
    """The ``info`` section of an OpenAPI document.

    ``title`` and ``version`` are required strings; ``description`` is an
    optional string that defaults to ``None`` when absent.
    """

    # Allow fields to be populated by their Python attribute names.
    model_config = ConfigDict(populate_by_name=True)

    title: str
    version: str
    description: Optional[str] = Field(default=None)
class OpenAPISpec(BaseModel):
    """Represents a validated OpenAPI specification.

    This class handles loading, validating, and accessing OpenAPI specifications
    from various sources: files, JSON strings, YAML strings, dictionaries, or a
    FastAPI application. Only the ``openapi``, ``info`` and ``paths`` top-level
    keys are declared explicitly; every other top-level key is retained as an
    extra attribute because the model is configured with ``extra="allow"``.
    """

    openapi: str = Field(..., description="OpenAPI version string")
    info: OpenAPIInfo = Field(..., description="Information about the API")
    paths: Dict[str, Any] = Field(..., description="API paths")

    # Keep undeclared top-level keys instead of rejecting them, and allow
    # population by Python field name.
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    @field_validator("openapi")
    @classmethod
    def validate_openapi_version(cls, v: str) -> str:
        """Validate the OpenAPI version string.

        Args:
            v: The version string to validate.

        Returns:
            The validated version string, unchanged.

        Raises:
            ValueError: If the version does not start with ``"3."``.
        """
        if not v.startswith("3."):
            raise ValueError("Only OpenAPI 3.x specifications are supported")
        return v

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAPISpec":
        """Create an OpenAPISpec instance from a dictionary.

        Args:
            data: A dictionary representing an OpenAPI specification.

        Returns:
            An OpenAPISpec instance.

        Raises:
            SpecInvalidError: If ``data`` fails model validation. The raw input
                is attached as the error details.
        """
        try:
            return cls.model_validate(data)
        except Exception as e:
            raise SpecInvalidError(
                f"Invalid OpenAPI specification: {str(e)}",
                details=data,
                cause=e,
            ) from e

    @classmethod
    def from_json(cls, json_str: str) -> "OpenAPISpec":
        """Create an OpenAPISpec instance from a JSON string.

        Args:
            json_str: A JSON string representing an OpenAPI specification.

        Returns:
            An OpenAPISpec instance.

        Raises:
            SpecInvalidError: If the string is not valid JSON (details carry its
                first 100 characters) or not a valid OpenAPI specification.
        """
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise SpecInvalidError(
                f"Invalid JSON: {str(e)}",
                details={"json_str": json_str[:100]},
                cause=e,
            ) from e
        # Validation failures surface as SpecInvalidError from from_dict.
        return cls.from_dict(data)

    @classmethod
    def from_yaml(cls, yaml_str: str) -> "OpenAPISpec":
        """Create an OpenAPISpec instance from a YAML string.

        The string is parsed with ``yaml.safe_load``, so arbitrary Python object
        construction tags are not honoured.

        Args:
            yaml_str: A YAML string representing an OpenAPI specification.

        Returns:
            An OpenAPISpec instance.

        Raises:
            SpecInvalidError: If the string is not valid YAML (details carry its
                first 100 characters) or not a valid OpenAPI specification.
        """
        try:
            data = yaml.safe_load(yaml_str)
        except yaml.YAMLError as e:
            raise SpecInvalidError(
                f"Invalid YAML: {str(e)}",
                details={"yaml_str": yaml_str[:100]},
                cause=e,
            ) from e
        # Validation failures surface as SpecInvalidError from from_dict.
        return cls.from_dict(data)

    @classmethod
    def from_file(cls, file_path: Union[str, Path]) -> "OpenAPISpec":
        """Create an OpenAPISpec instance from a file.

        The parser is chosen from the (case-insensitive) extension of the
        resolved path: ``.json`` uses JSON, ``.yaml``/``.yml`` use YAML. The
        file is read as UTF-8.

        Args:
            file_path: Path to a JSON or YAML file containing an OpenAPI
                specification.

        Returns:
            An OpenAPISpec instance.

        Raises:
            SpecInvalidError: If the extension is unsupported, the file cannot be
                read or decoded as UTF-8, the content is not valid JSON/YAML, or
                it is not a valid OpenAPI specification.
        """
        path = Path(file_path)
        try:
            # resolve() follows symlinks, so the extension checked below is
            # that of the real target file.
            path = path.resolve()
            suffix = path.suffix.lower()
            # Reject unsupported files before reading/decoding their content.
            if suffix not in (".json", ".yaml", ".yml"):
                raise SpecInvalidError(
                    f"Unsupported file extension: {path.suffix}. Only .json, .yaml, and .yml files are supported.",
                    details={"file_path": str(path)},
                )
            content = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an OSError; without it a
            # non-UTF-8 file would escape as a raw decode error.
            raise SpecInvalidError(
                f"Failed to read file: {str(e)}",
                details={"file_path": str(path)},
                cause=e,
            ) from e

        if suffix == ".json":
            return cls.from_json(content)
        return cls.from_yaml(content)

    @classmethod
    def from_fastapi(cls, app: "FastAPI") -> "OpenAPISpec":
        """Create an OpenAPISpec instance from a FastAPI application.

        Args:
            app: A FastAPI application instance.

        Returns:
            An OpenAPISpec instance built from ``app.openapi()``.

        Raises:
            SpecInvalidError: If FastAPI is not installed, ``app`` is not a
                FastAPI instance, ``app.openapi()`` raises, or the generated
                schema is not a valid OpenAPI specification.
        """
        # Imported lazily so FastAPI stays an optional dependency.
        try:
            from fastapi import FastAPI
        except ImportError as e:
            raise SpecInvalidError(
                "FastAPI is not installed. Please install it with: pip install fastapi",
                details={"missing_package": "fastapi"},
                cause=e,
            ) from e

        if not isinstance(app, FastAPI):
            raise SpecInvalidError(
                f"Expected a FastAPI instance, got {type(app).__name__}",
                details={"app_type": type(app).__name__},
            )

        try:
            # Schema generation runs application code, so any exception is
            # possible here; wrap it in the module's error type.
            openapi_schema = app.openapi()
        except Exception as e:
            raise SpecInvalidError(
                f"Failed to extract OpenAPI specification from FastAPI app: {str(e)}",
                details={"app_title": getattr(app, "title", "Unknown")},
                cause=e,
            ) from e

        # from_dict already raises SpecInvalidError on validation failure;
        # let it propagate instead of re-wrapping it.
        return cls.from_dict(openapi_schema)