-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathgenerate_data.py
More file actions
239 lines (205 loc) · 8 KB
/
generate_data.py
File metadata and controls
239 lines (205 loc) · 8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from pathlib import Path
from typing import List, Optional, Union
import black
import click
import httpx
import isort
import orjson
import yaml # type: ignore
from black.report import NothingChanged # type: ignore
from httpx import ConnectError, ConnectTimeout
from pydantic import ValidationError
from .common import FormatOptions, Formatter, HTTPLibrary, PydanticVersion
from .language_converters.python.jinja_config import SERVICE_TEMPLATE, create_jinja_env
from .models import ConversionResult
from .parsers import (
generate_code_3_0,
generate_code_3_1,
parse_openapi_3_0,
parse_openapi_3_1,
)
from .version_detector import detect_openapi_version
def write_code(path: Path, content: str, formatter: Formatter) -> None:
    """
    Write the content to the file at the given path, optionally formatting it first.

    :param path: The path to the file.
    :param content: The content to write.
    :param formatter: The formatter applied to the code written.
    :raises NotImplementedError: If ``formatter`` is neither ``Formatter.BLACK``
        nor ``Formatter.NONE``.
    """
    if formatter == Formatter.BLACK:
        formatted_content = format_using_black(content)
    elif formatter == Formatter.NONE:
        formatted_content = content
    else:
        raise NotImplementedError(
            f"Missing implementation for formatter {formatter!r}."
        )
    # Python source files are UTF-8 by default (PEP 3120). Write them as UTF-8
    # explicitly instead of relying on the platform's locale encoding, which can
    # fail on (or mis-encode) non-ASCII text in the generated code, e.g. on Windows.
    with open(path, "w", encoding="utf-8") as f:
        f.write(formatted_content)
def format_using_black(content: str) -> str:
    """
    Format Python source with black, then sort its imports with isort.

    :param content: The Python source code to format.
    :return: The formatted source code.
    """
    try:
        formatted_content = black.format_file_contents(
            content,
            fast=FormatOptions.skip_validation,
            mode=black.FileMode(line_length=FormatOptions.line_length),
        )
    except NothingChanged:
        # black signals "already formatted" with an exception. Still run isort
        # below, so import ordering does not depend on whether black rewrote
        # anything else in the file.
        formatted_content = content
    return isort.code(formatted_content, line_length=FormatOptions.line_length)
def _decode_spec(content: str, source: Union[str, Path]):
    """
    Decode the text of an OpenAPI specification, trying JSON first and YAML second.

    Args:
        content: The raw text of the specification.
        source: Where the text came from; only used in the error message.
    Returns:
        The decoded document as produced by ``orjson.loads`` or ``yaml.safe_load``.
    Raises:
        yaml.YAMLError: If the text is neither valid JSON nor valid YAML.
    """
    try:
        return orjson.loads(content)
    except orjson.JSONDecodeError:
        pass
    try:
        # safe_load only constructs plain Python objects, so a spec fetched from
        # an untrusted location cannot instantiate arbitrary classes.
        return yaml.safe_load(content)
    except yaml.YAMLError as e:
        click.echo(f"File {source} is neither a valid JSON nor YAML file: {str(e)}")
        raise


def get_open_api(source: Union[str, Path]):
    """
    Tries to fetch the openapi specification file from the web or load from a local file.
    Supports both JSON and YAML formats. Returns the according OpenAPI object.
    Automatically supports OpenAPI 3.0 and 3.1 specifications with intelligent version detection.
    Args:
        source: URL (a str starting with http:// or https://) or file path to the
            OpenAPI specification
    Returns:
        tuple: (OpenAPI object, version) where version is "3.0" or "3.1"
    Raises:
        FileNotFoundError: If the specified file cannot be found
        ConnectError: If the URL cannot be accessed
        httpx.HTTPStatusError: If the server answers the URL with an HTTP error status
        ValidationError: If the specification is invalid
        ValueError: If the detected OpenAPI version is neither 3.0 nor 3.1
        YAMLError: If the content is neither valid JSON nor valid YAML
    """
    try:
        if isinstance(source, str) and source.startswith(("http://", "https://")):
            # Handle remote files.
            response = httpx.get(source)
            # Reject HTTP error responses up front instead of trying to parse an
            # error page as a specification.
            response.raise_for_status()
            data = _decode_spec(response.text, source)
        else:
            # Handle local files. Decode explicitly as UTF-8 rather than with the
            # platform's locale encoding; "utf-8-sig" also drops a leading BOM.
            with open(source, "r", encoding="utf-8-sig") as f:
                data = _decode_spec(f.read(), source)

        # Detect version and parse with the matching parser.
        version = detect_openapi_version(data)
        if version == "3.0":
            openapi_obj = parse_openapi_3_0(data)  # type: ignore[assignment]
        elif version == "3.1":
            openapi_obj = parse_openapi_3_1(data)  # type: ignore[assignment]
        else:
            # Any other detected version is rejected.
            raise ValueError(
                f"Unsupported OpenAPI version: {version}. Only 3.0.x and 3.1.x are supported."
            )
        return openapi_obj, version
    except FileNotFoundError:
        click.echo(
            f"File {source} not found. Please make sure to pass the path to the OpenAPI specification."
        )
        raise
    except (ConnectError, ConnectTimeout):
        click.echo(f"Could not connect to {source}.")
        raise ConnectError(f"Could not connect to {source}.") from None
    except httpx.HTTPStatusError as e:
        click.echo(f"Could not fetch {source}: {e}")
        raise
    except ValidationError:
        click.echo(f"File {source} is not a valid OpenAPI 3.0+ specification.")
        raise
def write_data(
    data: ConversionResult, output: Union[str, Path], formatter: Formatter
) -> None:
    """
    Write a generated client package to the output folder.

    Creates the output folder and its ``models`` and ``services`` sub-folders if
    they do not exist, then writes:

    * one module per entry in ``data.models`` plus a ``models/__init__.py`` that
      star-imports every model module;
    * one module per entry in ``data.services`` that has at least one operation,
      rendered with the service Jinja template, plus an empty
      ``services/__init__.py``;
    * ``api_config.py``, one module per entry in ``data.static_files``, and a
      top-level ``__init__.py`` that star-imports models, services and api_config.

    :param data: The data to write.
    :param output: The path to the output folder.
    :param formatter: The formatter applied to the code written.
    """
    output_path = Path(output)
    models_path = output_path / "models"
    services_path = output_path / "services"

    # Create the folder structure of the output folder.
    for directory in (output_path, models_path, services_path):
        directory.mkdir(parents=True, exist_ok=True)

    # Write the models, remembering their module names for models/__init__.py.
    model_files: List[str] = []
    for model in data.models:
        model_files.append(model.file_name)
        write_code(models_path / f"{model.file_name}.py", model.content, formatter)

    # models/__init__.py star-imports every model module.
    write_code(
        models_path / "__init__.py",
        "\n".join(f"from .{file_name} import *" for file_name in model_files),
        formatter,
    )

    # Write the services; a service without operations produces no module.
    jinja_env = create_jinja_env()
    for service in data.services:
        if not service.operations:
            continue
        write_code(
            services_path / f"{service.file_name}.py",
            jinja_env.get_template(SERVICE_TEMPLATE).render(**service.model_dump()),
            formatter,
        )

    # services/__init__.py is written empty: the individual service modules are
    # not re-exported from it.
    write_code(services_path / "__init__.py", "", formatter)

    # Write the api_config.py file.
    write_code(output_path / "api_config.py", data.api_config.content, formatter)

    # Write static files (if any).
    for static_file in data.static_files:
        write_code(
            output_path / f"{static_file.file_name}.py", static_file.content, formatter
        )

    # Top-level __init__.py re-exports models, services and api_config.
    write_code(
        output_path / "__init__.py",
        "from .models import *\nfrom .services import *\nfrom .api_config import *",
        formatter,
    )
def generate_data(
    source: Union[str, Path],
    output: Union[str, Path],
    library: HTTPLibrary = HTTPLibrary.httpx,
    env_token_name: Optional[str] = None,
    use_orjson: bool = False,
    custom_template_path: Optional[str] = None,
    pydantic_version: PydanticVersion = PydanticVersion.V2,
    formatter: Formatter = Formatter.BLACK,
) -> None:
    """
    Generate Python code from an OpenAPI 3.0+ specification.

    Loads and parses the specification, runs the code generator matching its
    detected version, and writes the result to ``output``.

    :param source: URL or file path of the OpenAPI specification.
    :param output: Folder the generated package is written to.
    :param library: HTTP library choice, forwarded to the code generator.
    :param env_token_name: Forwarded to the code generator (presumably the name of
        an environment variable holding an API token — confirm in the generator).
    :param use_orjson: Forwarded to the code generator.
    :param custom_template_path: Forwarded to the code generator.
    :param pydantic_version: Forwarded to the code generator.
    :param formatter: Formatter applied to every file written.
    :raises ValueError: If the detected version has no matching generator.
    """
    openapi_obj, version = get_open_api(source)
    click.echo(f"Generating data from {source} (OpenAPI {version})")

    # Both generators take the same positional arguments; pick one by version.
    generator_by_version = {
        "3.0": generate_code_3_0,
        "3.1": generate_code_3_1,
    }
    generator = generator_by_version.get(version)
    if generator is None:
        raise ValueError(f"Unsupported OpenAPI version: {version}")

    conversion = generator(
        openapi_obj,  # type: ignore
        library,
        env_token_name,
        use_orjson,
        custom_template_path,
        pydantic_version,
    )
    write_data(conversion, output, formatter)