@@ -140,13 +140,14 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
140140 body_module .pop ("__init__" )
141141
142142 # Get all TypeVars
143- typevars = []
143+ typevars : list [ str ] = []
144144 for b in body_typevars :
145145 if isinstance (b , ast .Assign ):
146146 value = b .value
147147 if isinstance (value , ast .Call ):
148148 if value .func .id == "TypeVar" :
149- typevars .append (value .args [0 ].s )
149+ name = value .args [0 ].s
150+ typevars .append (name )
150151 print (typevars )
151152
152153 # Dict of module attributes per submodule
@@ -175,6 +176,18 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
175176 level = 0 ,
176177 ),
177178 )
179+
180+ for typevar in typevars :
181+ out .body .append (
182+ ast .Assign (
183+ targets = [ast .Name (id = typevar , ctx = ast .Store ())],
184+ value = ast .Call (
185+ func = ast .Name (id = "TypeVar" , ctx = ast .Load ()),
186+ args = [ast .Constant (value = typevar )],
187+ keywords = [],
188+ ),
189+ )
190+ )
178191
179192 # Create Protocols with __call__, representing functions
180193 for submodule , body in body_module .items ():
@@ -239,4 +252,4 @@ def generate(body_module: Mapping[str, list[ast.stmt]], out_path: Path) -> None:
239252 _attributes_to_protocol ("ArrayNamespace" , attributes , typevars ).stmt
240253 )
241254 out_path .parent .mkdir (parents = True , exist_ok = True )
242- out_path .write_text (ast .unparse (out ), "utf-8" )
255+ out_path .write_text (ast .unparse (ast . fix_missing_locations ( out ) ), "utf-8" )
0 commit comments