@@ -135,7 +135,7 @@ def _class_to_protocol(stmt: ast.ClassDef, typevars: Sequence[TypeVarInfo]) -> P
135135 )
136136
137137
138- def _attributes_to_protocol (name : str , attributes : Sequence [ModuleAttributes ], bases : list [ast .expr ] | None = None , typevars : Sequence [TypeVarInfo ] | None = None ) -> ProtocolData :
138+ def _attributes_to_protocol (name : str , attributes : Sequence [ModuleAttributes ], / , * , typevars : Sequence [ TypeVarInfo ], bases : list [ast .expr ] | None = None , typevars_force : Sequence [TypeVarInfo ] | None = None ) -> ProtocolData :
139139 """
140140 Convert a list of module attributes to a Protocol class.
141141
@@ -147,7 +147,9 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
147147 The attributes to include in the Protocol class.
148148 bases : list[ast.expr] | None, optional
149149 The base classes for the Protocol class, by default None, which defaults to [Protocol].
150- typevars : Sequence[TypeVarInfo] | None, optional
150+ typevars : Sequence[TypeVarInfo]
151+ The type variables used in the class.
152+ typevars_force : Sequence[TypeVarInfo] | None, optional
151153 The type variables used in the Protocol class, by default None, which defaults to the type variables used in the attributes.
152154
153155 Returns
@@ -167,8 +169,8 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
167169 )
168170 if a .docstring is not None :
169171 body .append (ast .Expr (value = ast .Constant (a .docstring )))
170- if typevars is None :
171- typevars = sorted ({ x for attribute in attributes for x in attribute .typevars_used }, key = lambda x : x . name . lower ())
172+ if typevars_force is None :
173+ typevars_force = [ t for t in typevars if any ( t in attr .typevars_used for attr in attributes )]
172174 return ProtocolData (
173175 stmt = ast .ClassDef (
174176 name = name ,
@@ -179,9 +181,9 @@ def _attributes_to_protocol(name: str, attributes: Sequence[ModuleAttributes], b
179181 ast .Name (id = "Protocol" ),
180182 ],
181183 body = body ,
182- type_params = [ast .TypeVar (name = t .name , bound = ast .Name (id = t .bound ) if t .bound else None ) for t in typevars ],
184+ type_params = [ast .TypeVar (name = t .name , bound = ast .Name (id = t .bound ) if t .bound else None ) for t in typevars_force ],
183185 ),
184- typevars_used = typevars ,
186+ typevars_used = typevars_force ,
185187 )
186188
187189
@@ -197,31 +199,11 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
197199 The output path where the generated Protocol classes will be saved.
198200
199201 """
200- body_typevars = body_module ["_types" ]
202+ body_module ["_types" ]
201203 del body_module ["__init__" ]
202204
203205 # Get all TypeVars
204- typevars : list [TypeVarInfo ] = []
205- for b in body_typevars :
206- if isinstance (b , ast .Assign ):
207- value = b .value
208- if isinstance (value , ast .Call ):
209- if isinstance (value .func , ast .Name ):
210- if value .func .id == "TypeVar" :
211- if isinstance (value .args [0 ], ast .Constant ):
212- name = value .args [0 ].s
213- typevars .append (
214- TypeVarInfo (
215- name = name ,
216- bound = {
217- "array" : "_array" ,
218- }.get (name , None ),
219- )
220- )
221- typevars += [TypeVarInfo (name = x ) for x in ["Capabilities" , "DefaultDataTypes" , "DataTypes" ]]
222- typevars = [t for t in typevars if t .name not in ["ellipsis" , "PyCapsule" , "SupportsBufferProtocol" ]]
223- print (typevars )
224- typevars = sorted (typevars , key = lambda x : x .name .lower ())
206+ typevars = [TypeVarInfo ("array" , "_array" ), TypeVarInfo ("dtype" ), TypeVarInfo ("device" ), TypeVarInfo ("_T_co" )]
225207
226208 # Dict of module attributes per submodule
227209 module_attributes : defaultdict [str , list [ModuleAttributes ]] = defaultdict (list )
@@ -254,6 +236,9 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
254236 elif isinstance (b , ast .Assign ):
255237 # _types.py contains Assigns which are not part of the Namespace
256238 if submodule == "_types" :
239+ if isinstance (b .targets [0 ], ast .Name ) and b .targets [0 ].id in ["Capabilities" , "DefaultDataTypes" , "DataTypes" ]:
240+ b = ast .parse (ast .unparse (b ).replace ("dtype" , "Any" )) # type: ignore
241+ out .body = [b , * out .body ]
257242 continue
258243 if not isinstance (b .targets [0 ], ast .Name ):
259244 continue
@@ -291,21 +276,21 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
291276 # Create Protocols for the main namespace
292277 OPTIONAL_SUBMODULES = ["fft" , "linalg" ]
293278 main_attributes = [attribute for submodule , attributes in module_attributes .items () for attribute in attributes if submodule not in OPTIONAL_SUBMODULES ]
294- main_protocol = _attributes_to_protocol ("ArrayNamespace" , main_attributes ).stmt
279+ main_protocol = _attributes_to_protocol ("ArrayNamespace" , main_attributes , typevars = typevars ).stmt
295280 out .body .append (main_protocol )
296281
297282 # Create Protocols for fft and linalg
298283 submodules : list [ModuleAttributes ] = []
299284 for submodule , attributes in module_attributes .items ():
300285 if submodule not in OPTIONAL_SUBMODULES :
301286 continue
302- data = _attributes_to_protocol (submodule [0 ].upper () + submodule [1 :] + "Namespace" , attributes )
287+ data = _attributes_to_protocol (submodule [0 ].upper () + submodule [1 :] + "Namespace" , attributes , typevars = typevars )
303288 out .body .append (data .stmt )
304289 if submodule in OPTIONAL_SUBMODULES :
305290 submodules .append (ModuleAttributes (submodule , data .name , None , [t for t in typevars if any (t in attr .typevars_used for attr in attributes )]))
306291
307292 # Create Full Protocol for the main namespace
308- out .body .append (_attributes_to_protocol ("ArrayNamespaceFull" , submodules , [ast .Subscript (ast .Name ("ArrayNamespace" ), ast .Tuple ([ast .Name (t .name ) for t in main_protocol .type_params ]))], typevars = [t for t in typevars if t .name in [s .name for s in main_protocol .type_params ]]).stmt ) # type: ignore
293+ out .body .append (_attributes_to_protocol ("ArrayNamespaceFull" , submodules , bases = [ast .Subscript (ast .Name ("ArrayNamespace" ), ast .Tuple ([ast .Name (t .name ) for t in main_protocol .type_params ]))], typevars = typevars , typevars_force = [t for t in typevars if t .name in [s .name for s in main_protocol .type_params ]]).stmt ) # type: ignore
309294
310295 # Replace TypeVars because of the name conflicts like "array: array"
311296 for node in ast .walk (out ):
@@ -344,6 +329,7 @@ def generate(body_module: dict[str, list[ast.stmt]], out_path: Path) -> None:
344329 Tuple,
345330 List,
346331 runtime_checkable,
332+ TypedDict
347333)
348334from types import EllipsisType as ellipsis
349335from typing_extensions import CapsuleType as PyCapsule
0 commit comments