11from __future__ import annotations
22
33import ast
4- import runpy
54import warnings
65from collections import defaultdict
76from collections .abc import Iterable , Sequence
@@ -138,7 +137,7 @@ def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> P
138137 )
139138
140139
141- def _attributes_to_protocol (name : str , attributes : Sequence [ModuleAttributes ]) -> ProtocolData :
140+ def _attributes_to_protocol (name : str , attributes : Sequence [ModuleAttributes ], bases : list [ ast . expr ] | None = None , typevars : Sequence [ TypeVarInfo ] | None = None ) -> ProtocolData :
142141 """
143142 Convert a list of module attributes to a Protocol class.
144143
@@ -148,6 +147,10 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -
148147 The name of the Protocol class.
149148 attributes : Sequence[ModuleAttributes]
150149 The attributes to include in the Protocol class.
150+ bases : list[ast.expr] | None, optional
151+ The base classes for the Protocol class, by default None, which defaults to [Protocol].
152+ typevars : Sequence[TypeVarInfo] | None, optional
153+ The type variables used in the Protocol class, by default None, which defaults to the type variables used in the attributes.
151154
152155 Returns
153156 -------
@@ -166,14 +169,16 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes]) -
166169 )
167170 if a .docstring is not None :
168171 body .append (ast .Expr (value = ast .Constant (a .docstring )))
169-
170- typevars = sorted ({x for attribute in attributes for x in attribute .typevars_used }, key = lambda x : x .name )
172+ if typevars is None :
173+ typevars = sorted ({x for attribute in attributes for x in attribute .typevars_used }, key = lambda x : x .name )
171174 return ProtocolData (
172175 stmt = ast .ClassDef (
173176 name = name ,
174177 decorator_list = [ast .Name (id = "runtime_checkable" )],
175178 keywords = [],
176- bases = [
179+ bases = bases
180+ if bases
181+ else [
177182 ast .Name (id = "Protocol" ),
178183 ],
179184 body = body ,
@@ -272,26 +277,31 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
272277 elif isinstance (b , ast .ClassDef ):
273278 data = _class_to_protocol (b , typevars )
274279 # add to output, do not add to module attributes
275- out .body .append (data .stmt )
280+ # add to first position
281+ out .body .insert (0 , data .stmt )
276282 elif isinstance (b , ast .Expr ):
277283 pass
278284 else :
279285 warnings .warn (f"Skipping { submodule } { b } " , stacklevel = 2 )
280286
287+ # Create Protocols for the main namespace
288+ OPTIONAL_SUBMODULES = ["fft" , "linalg" ]
289+ main_attributes = [attribute for submodule , attributes in module_attributes .items () for attribute in attributes if submodule not in OPTIONAL_SUBMODULES ]
290+ main_protocol = _attributes_to_protocol ("ArrayNamespace" , main_attributes ).stmt
291+ out .body .append (main_protocol )
292+
281293 # Create Protocols for fft and linalg
282294 submodules : list [ModuleAttributes ] = []
283- OPTIONAL_SUBMODULES = ["fft" , "linalg" ]
284295 for submodule , attributes in module_attributes .items ():
285296 if submodule not in OPTIONAL_SUBMODULES :
286297 continue
287298 data = _attributes_to_protocol (submodule [0 ].upper () + submodule [1 :] + "Namespace" , attributes )
288299 out .body .append (data .stmt )
289300 if submodule in OPTIONAL_SUBMODULES :
290- submodules .append (ModuleAttributes (submodule , data .name , None , []))
301+ submodules .append (ModuleAttributes (submodule , data .name , None , [t for t in typevars if any ( t in attr . typevars_used for attr in attributes ) ]))
291302
292- # Create Protocols for the main namespace
293- attributes = [attribute for submodule , attributes in module_attributes .items () for attribute in attributes if submodule not in OPTIONAL_SUBMODULES ] + submodules
294- out .body .append (_attributes_to_protocol ("ArrayNamespace" , attributes ).stmt )
303+ # Create Full Protocol for the main namespace
304+ out .body .append (_attributes_to_protocol ("ArrayNamespaceFull" , submodules , [ast .Subscript (ast .Name ("ArrayNamespace" ), ast .Tuple ([ast .Name (t .name ) for t in main_protocol .type_params ]))], typevars = [t for t in typevars if t .name in [s .name for s in main_protocol .type_params ]]).stmt ) # type: ignore
295305
296306 # Replace TypeVars because of the name conflicts like "array: array"
297307 for node in ast .walk (out ):
@@ -374,9 +384,3 @@ def generate_all(
374384 # get module bodies
375385 body_module = {path .stem : ast .parse (path .read_text ("utf-8" ).replace ("Dtype" , "dtype" ).replace ("Device" , "device" )).body for path in dir_path .rglob ("*.py" )}
376386 generate (body_module , (Path (out_path ) / dir_path .name ).with_suffix (".py" ))
377-
378- import sys
379-
380- # run ssort, otherwise it is broken
381- sys .argv = ["ssort" , "src/array_api" ]
382- runpy .run_module ("ssort" )
0 commit comments