Skip to content

Commit eeab34f

Browse files
author
Anders Brams
committed
fix: writing type aliases as python3.12 type declarations
1 parent ec85bb0 commit eeab34f

6 files changed

Lines changed: 129 additions & 8 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{{ alias.name }}: TypeAlias = {{ alias.annotation | annotation }}
1+
type {{ alias.name }} = {{ alias.annotation | annotation }}

openapi_python/generator/templates/types.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from enum import Enum
55
{% if has_field_descriptions -%}
66
from importlib import import_module
77
{% endif -%}
8-
from typing import Annotated, Any, Literal, NotRequired, TypeAlias, TypedDict
8+
from typing import Annotated, Any, Literal, NotRequired, TypedDict
99

1010
{% if has_field_descriptions %}
1111
def _openapi_python_field(description: str) -> object:

tests/contract/discriminated_union/app.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import FastAPI
66
from pydantic import BaseModel, Field
77

8-
app = FastAPI()
8+
app = FastAPI(separate_input_output_schemas=False)
99

1010

1111
class Cat(BaseModel):
@@ -34,3 +34,35 @@ def get_pet(pet_id: int) -> PetEnvelope:
3434
@app.post("/pets", response_model=PetEnvelope)
3535
def create_pet(body: PetEnvelope) -> PetEnvelope:
3636
return body
37+
38+
39+
class CiscoSiteArea(BaseModel):
40+
type: Literal["area"]
41+
name: str
42+
43+
44+
class CiscoSiteBuilding(BaseModel):
45+
type: Literal["building"]
46+
name: str
47+
country: str
48+
49+
50+
class CiscoSiteFloor(BaseModel):
51+
type: Literal["floor"]
52+
name: str
53+
floor_number: int
54+
55+
56+
type CiscoSite = Annotated[
57+
CiscoSiteArea | CiscoSiteBuilding | CiscoSiteFloor,
58+
Field(discriminator="type"),
59+
]
60+
61+
62+
class CiscoAccessPointConfig(BaseModel):
63+
site_hierarchy: list[CiscoSite]
64+
65+
66+
@app.post("/sites", response_model=CiscoAccessPointConfig)
67+
def create_site(body: CiscoAccessPointConfig) -> CiscoAccessPointConfig:
68+
return body

tests/contract/discriminated_union/generate.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,54 @@
1-
from __future__ import annotations
2-
31
import json
42
from pathlib import Path
3+
from typing import cast
54

65
from app import app
6+
from fastapi import FastAPI
77

88
from openapi_python.generator import GenerationRequest, generate_client
99

10+
OUTPUT_DIR = Path(__file__).parent / "generated"
11+
12+
13+
def _service_b_openapi() -> dict:
14+
from generated.service_a_client.types import CiscoAccessPointConfig, PetEnvelope
15+
16+
service_b = FastAPI(separate_input_output_schemas=False)
17+
18+
@service_b.get("/pets/{pet_id}", response_model=PetEnvelope)
19+
def get_pet(pet_id: int) -> PetEnvelope:
20+
return cast(
21+
PetEnvelope,
22+
{"pet": {"pet_type": "cat", "lives": 9}, "request_id": f"pet_{pet_id}"},
23+
)
24+
25+
@service_b.post("/pets", response_model=PetEnvelope)
26+
def create_pet(body: PetEnvelope) -> PetEnvelope:
27+
return body
28+
29+
@service_b.post("/sites", response_model=CiscoAccessPointConfig)
30+
def create_site(body: CiscoAccessPointConfig) -> CiscoAccessPointConfig:
31+
return body
32+
33+
return service_b.openapi()
34+
1035

1136
def main() -> None:
1237
generate_client(
1338
GenerationRequest(
14-
output_dir=Path(__file__).parent / "generated",
39+
output_dir=OUTPUT_DIR,
40+
package_name="service_a_client",
1541
spec_json=json.dumps(app.openapi()),
1642
overwrite=True,
1743
)
1844
)
45+
generate_client(
46+
GenerationRequest(
47+
output_dir=OUTPUT_DIR,
48+
spec_json=json.dumps(_service_b_openapi()),
49+
overwrite=True,
50+
)
51+
)
1952

2053

2154
if __name__ == "__main__":

tests/contract/discriminated_union/usage_async.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
from typing import Literal, assert_type
44

55
from generated.my_client import AsyncClient
6-
from generated.my_client.types import Cat, Dog, PetEnvelope
6+
from generated.my_client.types import (
7+
Cat,
8+
CiscoAccessPointConfig,
9+
CiscoSite,
10+
CiscoSiteArea,
11+
CiscoSiteBuilding,
12+
CiscoSiteFloor,
13+
Dog,
14+
PetEnvelope,
15+
)
716

817
async_client = AsyncClient(base_url="http://testserver")
918

@@ -29,3 +38,22 @@ async def use_async_client() -> None:
2938
fetched = await async_client.get("/pets/{pet_id}")(params={"pet_id": 1})
3039
assert_type(fetched, PetEnvelope)
3140
assert_type(fetched["pet"], Cat | Dog)
41+
42+
area: CiscoSiteArea = {"type": "area", "name": "Global"}
43+
building: CiscoSiteBuilding = {
44+
"type": "building",
45+
"name": "HQ",
46+
"country": "DK",
47+
}
48+
floor: CiscoSiteFloor = {
49+
"type": "floor",
50+
"name": "Ground",
51+
"floor_number": 0,
52+
}
53+
site: CiscoSite = floor
54+
body: CiscoAccessPointConfig = {"site_hierarchy": [area, building, site]}
55+
56+
created = await async_client.post("/sites")(body=body)
57+
assert_type(created, CiscoAccessPointConfig)
58+
assert_type(created["site_hierarchy"][0], CiscoSite)
59+
assert_type(floor["type"], Literal["floor"])

tests/contract/discriminated_union/usage_sync.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
from typing import Literal, assert_type
44

55
from generated.my_client import Client
6-
from generated.my_client.types import Cat, Dog, PetEnvelope
6+
from generated.my_client.types import (
7+
Cat,
8+
CiscoAccessPointConfig,
9+
CiscoSite,
10+
CiscoSiteArea,
11+
CiscoSiteBuilding,
12+
CiscoSiteFloor,
13+
Dog,
14+
PetEnvelope,
15+
)
716

817
client = Client(base_url="http://testserver")
918

@@ -27,3 +36,22 @@
2736
fetched = client.get("/pets/{pet_id}")(params={"pet_id": 1})
2837
assert_type(fetched, PetEnvelope)
2938
assert_type(fetched["pet"], Cat | Dog)
39+
40+
area: CiscoSiteArea = {"type": "area", "name": "Global"}
41+
building: CiscoSiteBuilding = {
42+
"type": "building",
43+
"name": "HQ",
44+
"country": "DK",
45+
}
46+
floor: CiscoSiteFloor = {
47+
"type": "floor",
48+
"name": "Ground",
49+
"floor_number": 0,
50+
}
51+
site: CiscoSite = floor
52+
body: CiscoAccessPointConfig = {"site_hierarchy": [area, building, site]}
53+
54+
created = client.post("/sites")(body=body)
55+
assert_type(created, CiscoAccessPointConfig)
56+
assert_type(created["site_hierarchy"][0], CiscoSite)
57+
assert_type(floor["type"], Literal["floor"])

0 commit comments

Comments
 (0)