22
33import ast
44from collections import defaultdict
5- from collections .abc import Mapping , Sequence
5+ from collections .abc import Iterable , Mapping , Sequence
66from copy import deepcopy
77from pathlib import Path
88
99import attrs
1010
11+ @attrs .frozen ()
12+ class TypeVarInfo :
13+ name : str
14+ bound : str | None = None
1115
1216@attrs .frozen ()
1317class ProtocolData :
1418 stmt : ast .ClassDef
15- typevars_used : Sequence [ str ]
19+ typevars_used : Iterable [ TypeVarInfo ]
1620 name : str
1721
1822
19- def _function_to_protocol (stmt : ast .FunctionDef , typevars : list [str ]) -> ProtocolData :
23+ def _function_to_protocol (stmt : ast .FunctionDef , typevars : list [TypeVarInfo ]) -> ProtocolData :
2024 """
2125 Convert a function definition to a Protocol class.
2226
@@ -41,51 +45,48 @@ def _function_to_protocol(stmt: ast.FunctionDef, typevars: list[str]) -> Protoco
4145 stmt .args .posonlyargs .insert (0 , ast .arg (arg = "self" ))
4246 stmt .decorator_list .append (ast .Name (id = "abstractmethod" ))
4347 args = ast .unparse (stmt .args )
44- typevars = [typevar for typevar in typevars if typevar in args ]
48+ typevars = [typevar for typevar in typevars if typevar . name in args ]
4549
4650 # Construct the protocol
4751 stmt_new = ast .ClassDef (
4852 name = name ,
4953 decorator_list = [ast .Name (id = "runtime_checkable" )],
5054 keywords = [],
5155 bases = [
52- ast .Subscript (
53- value = ast .Name (id = "Protocol" ),
54- slice = ast .Tuple (elts = [ast .Name (typevar ) for typevar in typevars ]),
55- )
56+ ast .Name (id = "Protocol" ),
5657 ],
5758 body = (
5859 [ast .Expr (value = ast .Constant (docstring , kind = None ))]
5960 if docstring is not None
6061 else []
6162 )
6263 + [stmt ],
63- type_params = [],
64+ type_params = [ast . TypeVar ( name = t . name , bound = ast . Name ( id = t . bound ) if t . bound else None ) for t in typevars ],
6465 ) # type: ignore[call-arg]
6566 return ProtocolData (
6667 stmt = stmt_new ,
6768 typevars_used = typevars ,
68- name = name + (f"[{ ', ' .join (typevars )} ]" if typevars else "" ),
69+ name = name + (f"[{ ', ' .join ([ t . name for t in typevars ] )} ]" if typevars else "" ),
6970 )
7071
7172
72- def _class_to_protocol (stmt : ast .ClassDef , typevars : list [str ]) -> ProtocolData :
73- typevars = [typevar for typevar in typevars if typevar in ast .unparse (stmt )]
73+ def _class_to_protocol (stmt : ast .ClassDef , typevars : list [TypeVarInfo ]) -> ProtocolData :
74+ typevars = [typevar for typevar in typevars if typevar . name in ast .unparse (stmt )]
7475 stmt .bases = [
75- ast .Subscript (
76- value = ast . Name ( id = "Protocol" ),
77- slice = ast . Tuple ( elts = [ ast . Name ( typevar ) for typevar in typevars ]),
78- )
76+ ast .Name ( id = "Protocol" ),
77+ ]
78+ stmt . type_params = [
79+ ast . TypeVar ( name = t . name , bound = ast . Name ( id = t . bound ) if t . bound else None ) for t in typevars
7980 ]
8081 return ProtocolData (
8182 stmt = stmt ,
8283 typevars_used = typevars ,
83- name = stmt .name + (f"[{ ', ' .join (typevars )} ]" if typevars else "" ),
84+ name = stmt .name + (f"[{ ', ' .join ([ t . name for t in typevars ] )} ]" if typevars else "" ),
8485 )
8586
8687
8788def _attributes_to_protocol (
88- name , attributes : list [tuple [str , str , str | None , list ]], typevars : list [ str ]
89+ name , attributes : list [tuple [str , str , str | None , list [ TypeVarInfo ]] ]
8990) -> ProtocolData :
9091 body = []
9192 for attribute , type , docstring , _ in attributes :
@@ -106,16 +107,15 @@ def _attributes_to_protocol(
106107 decorator_list = [ast .Name (id = "runtime_checkable" )],
107108 keywords = [],
108109 bases = [
109- ast .Subscript (
110- value = ast .Name (id = "Protocol" ),
111- slice = ast .Tuple (elts = [ast .Name (typevar ) for typevar in typevars ]),
112- )
110+ ast .Name (id = "Protocol" ),
113111 ],
114112 body = body ,
115- type_params = [],
113+ type_params = [
114+ ast .TypeVar (name = t .name , bound = ast .Name (id = t .bound ) if t .bound else None ) for t in typevars
115+ ],
116116 ),
117117 typevars_used = typevars ,
118- name = name + (f"[{ ', ' .join (typevars )} ]" if typevars else "" ),
118+ name = name + (f"[{ ', ' .join ([ t . name for t in typevars ] )} ]" if typevars else "" ),
119119 )
120120
121121
@@ -147,14 +147,14 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
147147 body_module .pop ("__init__" )
148148
149149 # Get all TypeVars
150- typevars : list [str ] = []
150+ typevars : list [TypeVarInfo ] = []
151151 for b in body_typevars :
152152 if isinstance (b , ast .Assign ):
153153 value = b .value
154154 if isinstance (value , ast .Call ):
155155 if value .func .id == "TypeVar" :
156156 name = value .args [0 ].s
157- typevars .append (name )
157+ typevars .append (TypeVarInfo ( name = name , bound = None ) )
158158 print (typevars )
159159
160160 # Dict of module attributes per submodule
@@ -184,17 +184,17 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
184184 ),
185185 )
186186
187- for typevar in typevars :
188- out .body .append (
189- ast .Assign (
190- targets = [ast .Name (id = typevar , ctx = ast .Store ())],
191- value = ast .Call (
192- func = ast .Name (id = "TypeVar" , ctx = ast .Load ()),
193- args = [ast .Constant (value = typevar )],
194- keywords = [],
195- ),
196- )
197- )
187+ # for typevar in typevars:
188+ # out.body.append(
189+ # ast.Assign(
190+ # targets=[ast.Name(id=typevar, ctx=ast.Store())],
191+ # value=ast.Call(
192+ # func=ast.Name(id="TypeVar", ctx=ast.Load()),
193+ # args=[ast.Constant(value=typevar)],
194+ # keywords=[],
195+ # ),
196+ # )
197+ # )
198198
199199 # Create Protocols with __call__, representing functions
200200 for submodule , body in body_module .items ():
@@ -244,7 +244,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
244244 if submodule not in OPTIONAL_SUBMODULES :
245245 continue
246246 data = _attributes_to_protocol (
247- submodule [0 ].upper () + submodule [1 :] + "Namespace" , attributes , typevars
247+ submodule [0 ].upper () + submodule [1 :] + "Namespace" , attributes
248248 )
249249 out .body .append (data .stmt )
250250 if submodule in OPTIONAL_SUBMODULES :
@@ -258,7 +258,7 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
258258 if submodule not in OPTIONAL_SUBMODULES
259259 ] + submodules
260260 out .body .append (
261- _attributes_to_protocol ("ArrayNamespace" , attributes , typevars ).stmt
261+ _attributes_to_protocol ("ArrayNamespace" , attributes ).stmt
262262 )
263263 out_path .parent .mkdir (parents = True , exist_ok = True )
264264 out_path .write_text (ast .unparse (ast .fix_missing_locations (out )), "utf-8" )
0 commit comments