Skip to content

Commit 6590ee7

Browse files
committed
Extract type information from the '_type_adapter' attribute instead of the 'type_' attribute of the path parameter ModelField instance; 'type_' is a Pydantic v1 attribute that has been removed
1 parent d5b2c1f commit 6590ee7

1 file changed

Lines changed: 36 additions & 3 deletions

File tree

src/murfey/cli/generate_route_manifest.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from argparse import ArgumentParser
1212
from pathlib import Path
1313
from types import ModuleType
14-
from typing import Any
14+
from typing import Annotated, Union, get_args, get_origin
1515

1616
import yaml
1717
from fastapi import APIRouter
@@ -20,6 +20,39 @@
2020
from murfey.cli import PrettierDumper
2121

2222

23+
def extract_base_type(annotation):
24+
"""
25+
Given a Python type annotation, return its underlying base type.
26+
27+
This function unwraps `typing.Annotated` to extract the annotated type
28+
and simplifies `Optional[T]` / `Union[T, None]` to `T`. All other union
29+
types and complex annotations are returned unchanged.
30+
31+
Parameters
32+
----------
33+
annotation:
34+
A Python type annotation (e.g. int, Annotated[int, ...], Optional[int],
35+
Union[int, str])
36+
37+
Returns
38+
-------
39+
The unwrapped base type, or the original annotation if no unambiguous base
40+
type can be determined.
41+
"""
42+
# Unwrap Annotated
43+
if get_origin(annotation) is Annotated:
44+
annotation = get_args(annotation)[0]
45+
46+
# Unwrap Optional / Union:param
47+
origin = get_origin(annotation)
48+
if origin is Union:
49+
args = [a for a in get_args(annotation) if a is not type(None)]
50+
if len(args) == 1:
51+
return args[0]
52+
53+
return annotation
54+
55+
2356
def find_routers(name: str) -> dict[str, APIRouter]:
2457
def _extract_routers_from_module(module: ModuleType):
2558
routers = {}
@@ -74,7 +107,7 @@ def get_route_manifest(routers: dict[str, APIRouter]):
74107
for route in router.routes:
75108
path_params = []
76109
for param in route.dependant.path_params:
77-
param_type = param.type_ if param.type_ is not None else Any
110+
param_type = extract_base_type(param._type_adapter._type)
78111
param_info = {
79112
"name": param.name if hasattr(param, "name") else "",
80113
"type": (
@@ -86,7 +119,7 @@ def get_route_manifest(routers: dict[str, APIRouter]):
86119
path_params.append(param_info)
87120
for route_dependency in route.dependant.dependencies:
88121
for param in route_dependency.path_params:
89-
param_type = param.type_ if param.type_ is not None else Any
122+
param_type = extract_base_type(param._type_adapter._type)
90123
param_info = {
91124
"name": param.name if hasattr(param, "name") else "",
92125
"type": (

0 commit comments

Comments
 (0)