@@ -16,8 +16,6 @@ class ProtocolData:
1616 name : str
1717
1818
19-
20-
2119def _function_to_protocol (stmt : ast .FunctionDef , typevars : list [str ]) -> ProtocolData :
2220 """
2321 Convert a function definition to a Protocol class.
@@ -56,7 +54,12 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
5654 slice = ast .Tuple (elts = [ast .Name (typevar ) for typevar in typevars ]),
5755 )
5856 ],
59- body = ([ast .Expr (value = ast .Constant (docstring , kind = None ))] if docstring is not None else []) + [stmt ],
57+ body = (
58+ [ast .Expr (value = ast .Constant (docstring , kind = None ))]
59+ if docstring is not None
60+ else []
61+ )
62+ + [stmt ],
6063 type_params = [],
6164 ) # type: ignore[call-arg]
6265 return ProtocolData (
@@ -65,9 +68,8 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
6568 name = name + (f"[{ ', ' .join (typevars )} ]" if typevars else "" ),
6669 )
6770
68- def _class_to_protocol (
69- stmt : ast .ClassDef , typevars : list [str ]
70- ) -> ProtocolData :
71+
72+ def _class_to_protocol (stmt : ast .ClassDef , typevars : list [str ]) -> ProtocolData :
7173 typevars = [typevar for typevar in typevars if typevar in ast .unparse (stmt )]
7274 stmt .bases = [
7375 ast .Subscript (
@@ -81,6 +83,7 @@ def _class_to_protocol(
8183 name = stmt .name + (f"[{ ', ' .join (typevars )} ]" if typevars else "" ),
8284 )
8385
86+
8487def _attributes_to_protocol (
8588 name , attributes : list [tuple [str , str , str | None , list ]], typevars : list [str ]
8689) -> ProtocolData :
@@ -129,7 +132,11 @@ def generate_all(
129132 for dir_path in (Path (cache_dir ) / Path ("src" ) / "array_api_stubs" ).glob ("**/" ):
130133 # get module bodies
131134 body_module = {
132- path .stem : ast .parse (path .read_text ("utf-8" )).body
135+ path .stem : ast .parse (
136+ path .read_text ("utf-8" )
137+ .replace ("Dtype" , "dtype" )
138+ .replace ("Device" , "device" )
139+ ).body
133140 for path in dir_path .rglob ("*.py" )
134141 }
135142 generate (body_module , Path (out_path ) / dir_path .name / out_name )
@@ -176,7 +183,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
176183 level = 0 ,
177184 ),
178185 )
179-
186+
180187 for typevar in typevars :
181188 out .body .append (
182189 ast .Assign (
@@ -202,6 +209,8 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
202209 module_attributes [submodule ].append (
203210 (b .name , data .name , None , data .typevars_used )
204211 )
212+ if "alias" in (ast .get_docstring (b ) or "" ):
213+ continue
205214 out .body .append (data .stmt )
206215 elif isinstance (b , ast .Assign ):
207216 if submodule == "_types" :
0 commit comments