Skip to content

Commit 7dc4acc

Browse files
author
Doug Borg
committed
feat: add OpenAPI 3.1 support (version detection, parsers, unified generate_data)
1 parent b1e1082 commit 7dc4acc

11 files changed

Lines changed: 1409 additions & 68 deletions

File tree

src/openapi_python_generator/__main__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from openapi_python_generator.common import Formatter, HTTPLibrary, PydanticVersion
77
from openapi_python_generator.generate_data import generate_data
88

9+
910
@click.command()
1011
@click.argument("source")
1112
@click.argument("output")
@@ -63,15 +64,22 @@ def main(
6364
formatter: Formatter = Formatter.BLACK,
6465
) -> None:
6566
"""
66-
Generate Python code from an OpenAPI 3.0 specification.
67+
Generate Python code from an OpenAPI 3.0+ specification.
6768
68-
Provide a SOURCE (file or URL) containing the OpenAPI 3 specification and
69+
Provide a SOURCE (file or URL) containing the OpenAPI 3.0+ specification and
6970
an OUTPUT path, where the resulting client is created.
7071
"""
7172
generate_data(
72-
source, output, library, env_token_name, use_orjson, custom_template_path, pydantic_version, formatter
73+
source,
74+
output,
75+
library,
76+
env_token_name,
77+
use_orjson,
78+
custom_template_path,
79+
pydantic_version,
80+
formatter,
7381
)
7482

7583

7684
if __name__ == "__main__": # pragma: no cover
77-
main()
85+
main()

src/openapi_python_generator/generate_data.py

Lines changed: 82 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44
from typing import Union
55

66
import black
7+
from black.report import NothingChanged # type: ignore
78
import click
89
import httpx
910
import isort
1011
import orjson
11-
import yaml
12-
from black import NothingChanged
12+
import yaml # type: ignore
1313
from httpx import ConnectError
1414
from httpx import ConnectTimeout
15-
from openapi_pydantic.v3.v3_0 import OpenAPI
1615
from pydantic import ValidationError
1716

1817
from .common import FormatOptions, Formatter, HTTPLibrary, PydanticVersion
19-
from .common import library_config_dict
20-
from .language_converters.python.generator import generator
2118
from .language_converters.python.jinja_config import SERVICE_TEMPLATE
2219
from .language_converters.python.jinja_config import create_jinja_env
2320
from .models import ConversionResult
21+
from .version_detector import detect_openapi_version
22+
from .parsers import (
23+
parse_openapi_30,
24+
parse_openapi_31,
25+
generate_code_30,
26+
generate_code_31,
27+
)
2428

2529

2630
def write_code(path: Path, content: str, formatter: Formatter) -> None:
@@ -35,31 +39,36 @@ def write_code(path: Path, content: str, formatter: Formatter) -> None:
3539
elif formatter == Formatter.NONE:
3640
formatted_contend = content
3741
else:
38-
raise NotImplementedError(f"Missing implementation for formatter {formatter!r}.")
42+
raise NotImplementedError(
43+
f"Missing implementation for formatter {formatter!r}."
44+
)
3945
with open(path, "w") as f:
4046
f.write(formatted_contend)
4147

4248

4349
def format_using_black(content: str) -> str:
4450
try:
4551
formatted_contend = black.format_file_contents(
46-
content, fast=FormatOptions.skip_validation, mode=black.FileMode(line_length=FormatOptions.line_length)
52+
content,
53+
fast=FormatOptions.skip_validation,
54+
mode=black.FileMode(line_length=FormatOptions.line_length),
4755
)
4856
except NothingChanged:
4957
return content
5058
return isort.code(formatted_contend, line_length=FormatOptions.line_length)
5159

5260

53-
def get_open_api(source: Union[str, Path]) -> OpenAPI:
61+
def get_open_api(source: Union[str, Path]):
5462
"""
5563
Tries to fetch the openapi specification file from the web or load from a local file.
5664
Supports both JSON and YAML formats. Returns the according OpenAPI object.
65+
Automatically supports OpenAPI 3.0 and 3.1 specifications with intelligent version detection.
5766
5867
Args:
5968
source: URL or file path to the OpenAPI specification
6069
6170
Returns:
62-
OpenAPI: Parsed OpenAPI specification object
71+
tuple: (OpenAPI object, version) where version is "3.0" or "3.1"
6372
6473
Raises:
6574
FileNotFoundError: If the specified file cannot be found
@@ -70,31 +79,46 @@ def get_open_api(source: Union[str, Path]) -> OpenAPI:
7079
try:
7180
# Handle remote files
7281
if not isinstance(source, Path) and (
73-
source.startswith("http://") or source.startswith("https://")
82+
source.startswith("http://") or source.startswith("https://")
7483
):
7584
content = httpx.get(source).text
7685
# Try JSON first, then YAML for remote files
7786
try:
78-
return OpenAPI(**orjson.loads(content))
87+
data = orjson.loads(content)
7988
except orjson.JSONDecodeError:
80-
return OpenAPI(**yaml.safe_load(content))
81-
82-
# Handle local files
83-
with open(source, "r") as f:
84-
file_content = f.read()
89+
data = yaml.safe_load(content)
90+
else:
91+
# Handle local files
92+
with open(source, "r") as f:
93+
file_content = f.read()
8594

86-
# Try JSON first
87-
try:
88-
return OpenAPI(**orjson.loads(file_content))
89-
except orjson.JSONDecodeError:
90-
# If JSON fails, try YAML
95+
# Try JSON first
9196
try:
92-
return OpenAPI(**yaml.safe_load(file_content))
93-
except yaml.YAMLError as e:
94-
click.echo(
95-
f"File {source} is neither a valid JSON nor YAML file: {str(e)}"
96-
)
97-
raise
97+
data = orjson.loads(file_content)
98+
except orjson.JSONDecodeError:
99+
# If JSON fails, try YAML
100+
try:
101+
data = yaml.safe_load(file_content)
102+
except yaml.YAMLError as e:
103+
click.echo(
104+
f"File {source} is neither a valid JSON nor YAML file: {str(e)}"
105+
)
106+
raise
107+
108+
# Detect version and parse with appropriate parser
109+
version = detect_openapi_version(data)
110+
111+
if version == "3.0":
112+
openapi_obj = parse_openapi_30(data) # type: ignore[assignment]
113+
elif version == "3.1":
114+
openapi_obj = parse_openapi_31(data) # type: ignore[assignment]
115+
else:
116+
# Unsupported version detected (version detection already limited to 3.0 / 3.1)
117+
raise ValueError(
118+
f"Unsupported OpenAPI version: {version}. Only 3.0.x and 3.1.x are supported."
119+
)
120+
121+
return openapi_obj, version
98122

99123
except FileNotFoundError:
100124
click.echo(
@@ -105,13 +129,13 @@ def get_open_api(source: Union[str, Path]) -> OpenAPI:
105129
click.echo(f"Could not connect to {source}.")
106130
raise ConnectError(f"Could not connect to {source}.") from None
107131
except ValidationError:
108-
click.echo(
109-
f"File {source} is not a valid OpenAPI 3.0 specification."
110-
)
132+
click.echo(f"File {source} is not a valid OpenAPI 3.0+ specification.")
111133
raise
112134

113135

114-
def write_data(data: ConversionResult, output: Union[str, Path], formatter: Formatter) -> None:
136+
def write_data(
137+
data: ConversionResult, output: Union[str, Path], formatter: Formatter
138+
) -> None:
115139
"""
116140
This function will firstly create the folder structure of output, if it doesn't exist. Then it will create the
117141
models from data.models into the models sub module of the output folder. After this, the services will be created
@@ -156,7 +180,7 @@ def write_data(data: ConversionResult, output: Union[str, Path], formatter: Form
156180
files.append(service.file_name)
157181
write_code(
158182
services_path / f"{service.file_name}.py",
159-
jinja_env.get_template(SERVICE_TEMPLATE).render(**service.dict()),
183+
jinja_env.get_template(SERVICE_TEMPLATE).render(**service.model_dump()),
160184
formatter,
161185
)
162186

@@ -177,26 +201,39 @@ def write_data(data: ConversionResult, output: Union[str, Path], formatter: Form
177201
def generate_data(
178202
source: Union[str, Path],
179203
output: Union[str, Path],
180-
library: Optional[HTTPLibrary] = HTTPLibrary.httpx,
204+
library: HTTPLibrary = HTTPLibrary.httpx,
181205
env_token_name: Optional[str] = None,
182206
use_orjson: bool = False,
183207
custom_template_path: Optional[str] = None,
184208
pydantic_version: PydanticVersion = PydanticVersion.V2,
185209
formatter: Formatter = Formatter.BLACK,
186210
) -> None:
187211
"""
188-
Generate Python code from an OpenAPI 3.0 specification.
212+
Generate Python code from an OpenAPI 3.0+ specification.
189213
"""
190-
data = get_open_api(source)
191-
click.echo(f"Generating data from {source}")
192-
193-
result = generator(
194-
data,
195-
library_config_dict[library],
196-
env_token_name,
197-
use_orjson,
198-
custom_template_path,
199-
pydantic_version,
200-
)
214+
openapi_obj, version = get_open_api(source)
215+
click.echo(f"Generating data from {source} (OpenAPI {version})")
216+
217+
# Use version-specific generator
218+
if version == "3.0":
219+
result = generate_code_30(
220+
openapi_obj, # type: ignore
221+
library,
222+
env_token_name,
223+
use_orjson,
224+
custom_template_path,
225+
pydantic_version,
226+
)
227+
elif version == "3.1":
228+
result = generate_code_31(
229+
openapi_obj, # type: ignore
230+
library,
231+
env_token_name,
232+
use_orjson,
233+
custom_template_path,
234+
pydantic_version,
235+
)
236+
else:
237+
raise ValueError(f"Unsupported OpenAPI version: {version}")
201238

202239
write_data(result, output, formatter)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
OpenAPI parsers for different specification versions.
3+
"""
4+
5+
from .openapi_30 import parse_openapi_30, generate_code_30
6+
from .openapi_31 import parse_openapi_31, generate_code_31
7+
8+
__all__ = [
9+
"parse_openapi_30",
10+
"generate_code_30",
11+
"parse_openapi_31",
12+
"generate_code_31",
13+
]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
OpenAPI 3.0 specific parsing and generation.
3+
"""
4+
5+
from typing import Optional
6+
7+
from openapi_pydantic.v3.v3_0 import OpenAPI
8+
9+
from openapi_python_generator.common import HTTPLibrary, PydanticVersion
10+
from openapi_python_generator.language_converters.python.generator import (
11+
generator as base_generator,
12+
)
13+
from openapi_python_generator.models import ConversionResult
14+
15+
16+
def parse_openapi_30(spec_data: dict) -> OpenAPI:
17+
"""
18+
Parse OpenAPI 3.0 specification data.
19+
20+
Args:
21+
spec_data: Dictionary containing OpenAPI 3.0 specification
22+
23+
Returns:
24+
OpenAPI: Parsed OpenAPI 3.0 specification object
25+
26+
Raises:
27+
ValidationError: If the specification is invalid
28+
"""
29+
return OpenAPI(**spec_data)
30+
31+
32+
def generate_code_30(
33+
data: OpenAPI,
34+
library: HTTPLibrary = HTTPLibrary.httpx,
35+
env_token_name: Optional[str] = None,
36+
use_orjson: bool = False,
37+
custom_template_path: Optional[str] = None,
38+
pydantic_version: PydanticVersion = PydanticVersion.V2,
39+
) -> ConversionResult:
40+
"""
41+
Generate Python code from OpenAPI 3.0 specification.
42+
43+
Args:
44+
data: OpenAPI 3.0 specification object
45+
library: HTTP library to use
46+
env_token_name: Environment variable name for token
47+
use_orjson: Whether to use orjson for serialization
48+
custom_template_path: Custom template path
49+
pydantic_version: Pydantic version to use
50+
51+
Returns:
52+
ConversionResult: Generated code and metadata
53+
"""
54+
from openapi_python_generator.common import library_config_dict
55+
56+
library_config = library_config_dict[library]
57+
58+
return base_generator(
59+
data=data,
60+
library_config=library_config,
61+
env_token_name=env_token_name,
62+
use_orjson=use_orjson,
63+
custom_template_path=custom_template_path,
64+
pydantic_version=pydantic_version,
65+
)

0 commit comments

Comments
 (0)