1717from cuda .tile ._ir import hir
1818from cuda .tile ._ir .type import ClosureDefaultPlaceholder
1919from cuda .tile ._passes .ast_util import ast_get_all_local_names
20- from cuda .tile ._stub import static_eval , static_assert
20+ from cuda .tile ._stub import static_eval , static_assert , static_iter
2121
2222
2323@lru_cache
@@ -145,6 +145,7 @@ def _get_function_hir_inner(func_def: ast.FunctionDef | ast.Lambda, signature: i
145145
146146class LoopKind (Enum ):
147147 FOR = auto ()
148+ STATIC_FOR = auto ()
148149 WHILE = auto ()
149150
150151
@@ -247,12 +248,13 @@ def decorate(f):
247248_expr_handlers : Dict [Type [ast .AST ], Callable ] = {}
248249
249250
251+ _KEYWORD_LIKE_FUNCS = (static_eval , static_assert , static_iter )
252+ _KEYWORD_LIKE_FUNC_NAMES = ("static_eval" , "static_assert" , "static_iter" )
253+
254+
250255@_register (_expr_handlers , ast .Call )
251256def _call_expr (call : ast .Call , ctx : _Context ) -> hir .Value :
252- kwd_func = _parse_keyword_like_func (call .func ,
253- (static_eval , static_assert ),
254- ("static_eval" , "static_assert" ),
255- ctx )
257+ kwd_func = _parse_keyword_like_func (call .func , ctx )
256258 if kwd_func is not None :
257259 if kwd_func == "static_eval" :
258260 if len (call .args ) != 1 or len (call .keywords ) != 0 :
@@ -272,8 +274,11 @@ def _call_expr(call: ast.Call, ctx: _Context) -> hir.Value:
272274 condition = _call_static_eval (call .args [0 ],
273275 hir .StaticEvalKind .STATIC_ASSERT_CONDITION , ctx )
274276 return ctx .call (hir .do_static_assert , (condition , message_block ))
277+ elif kwd_func == "static_iter" :
278+ raise TileSyntaxError ("static_iter() is only allowed as iterable in a `for` loop,"
279+ " i.e. `for i in ct.static_iter(...)`" )
275280 else :
276- assert False
281+ raise TileSyntaxError ( f" { kwd_func } is not expected here" )
277282 else :
278283 callee = _expr (call .func , ctx )
279284 args = tuple (_expr (a , ctx ) for a in call .args )
@@ -342,17 +347,14 @@ def _eval_ast_expr(expr: ast.expr, ctx: _Context):
342347 return eval (code , dict (ctx .frozen_globals ), {})
343348
344349
345- def _parse_keyword_like_func (expr : ast .expr ,
346- kwd_funcs : tuple [Callable , ...],
347- kwd_func_names : tuple [str , ...],
348- ctx : _Context ) -> str | None :
350+ def _parse_keyword_like_func (expr : ast .expr , ctx : _Context ) -> str | None :
349351 if isinstance (expr , ast .Name ):
350352 if (expr .id not in ctx .local_names
351- and ctx .frozen_globals .get (expr .id ) in kwd_funcs ):
352- idx = kwd_funcs .index (ctx .frozen_globals .get (expr .id ))
353- return kwd_func_names [idx ]
353+ and ctx .frozen_globals .get (expr .id ) in _KEYWORD_LIKE_FUNCS ):
354+ idx = _KEYWORD_LIKE_FUNCS .index (ctx .frozen_globals .get (expr .id ))
355+ return _KEYWORD_LIKE_FUNC_NAMES [idx ]
354356 elif isinstance (expr , ast .Attribute ):
355- if expr .attr in kwd_func_names and _is_cuda_tile_module (expr .value , ctx ):
357+ if expr .attr in _KEYWORD_LIKE_FUNC_NAMES and _is_cuda_tile_module (expr .value , ctx ):
356358 return expr .attr
357359 return None
358360
@@ -583,21 +585,40 @@ def _for_stmt(stmt: ast.For, ctx: _Context):
583585 if len (stmt .orelse ) > 0 :
584586 raise ctx .syntax_error ("'for-else' is not supported" , loc = stmt .orelse [0 ])
585587
586- iterable = _expr (stmt .iter , ctx )
587- if not isinstance (stmt .target , ast .Name ):
588- raise ctx .unsupported_syntax (stmt .target )
588+ static_iter_expr = _get_static_iter_expr (stmt .iter , ctx )
589+ if static_iter_expr is None :
590+ kind = LoopKind .FOR
591+ op = hir .loop
592+ iterable = _expr (stmt .iter , ctx )
593+ else :
594+ kind = LoopKind .STATIC_FOR
595+ op = hir .static_foreach
596+ with ctx .change_loc (static_iter_expr ):
597+ iterable = _call_static_eval (static_iter_expr ,
598+ hir .StaticEvalKind .STATIC_ITER_ITERABLE , ctx )
589599
590- ctx .parent_loops .append (LoopKind . FOR )
600+ ctx .parent_loops .append (kind )
591601 induction_var = ctx .make_value ()
592602 with ctx .new_block (params = (induction_var ,)) as body_block :
593- with ctx .change_loc (stmt .target ):
594- ctx .store (stmt .target .id , induction_var )
603+ _do_assign (induction_var , stmt .target , ctx )
595604 _stmt_list (stmt .body , ctx )
596- if body_block .jump is None :
605+ if body_block .jump is None and static_iter_expr is None :
597606 ctx .set_block_jump (hir .Jump .CONTINUE )
598607 ctx .parent_loops .pop ()
599608
600- ctx .call_void (hir .loop , (body_block , iterable ))
609+ ctx .call_void (op , (body_block , iterable ))
610+
611+
612+ def _get_static_iter_expr (expr : ast .expr , ctx : _Context ) -> ast .expr | None :
613+ if not isinstance (expr , ast .Call ):
614+ return None
615+ if _parse_keyword_like_func (expr .func , ctx ) != "static_iter" :
616+ return None
617+
618+ if len (expr .args ) != 1 or len (expr .keywords ) != 0 :
619+ raise ctx .syntax_error ("static_iter() expects a single expression" )
620+
621+ return expr .args [0 ]
601622
602623
603624def _bool_expr (expr : ast .AST , ctx : _Context ) -> hir .Value :
@@ -724,19 +745,21 @@ def _if_stmt(stmt: ast.If, ctx: _Context) -> None:
724745
725746@_register (_stmt_handlers , ast .Continue )
726747def _continue_stmt (stmt : ast .Continue , ctx : _Context ) -> None :
748+ if ctx .parent_loops and ctx .parent_loops [- 1 ] is LoopKind .STATIC_FOR :
749+ raise ctx .syntax_error ("Continue in a for loop with static_iter() is not supported" )
727750 ctx .set_block_jump (hir .Jump .CONTINUE )
728751
729752
730753@_register (_stmt_handlers , ast .Break )
731754def _break_stmt (stmt : ast .Break , ctx : _Context ) -> None :
732- if ctx .parent_loops and ctx .parent_loops [- 1 ] is LoopKind .FOR :
755+ if ctx .parent_loops and ctx .parent_loops [- 1 ] in ( LoopKind .FOR , LoopKind . STATIC_FOR ) :
733756 raise ctx .syntax_error ("Break in a for loop is not supported" )
734757 ctx .set_block_jump (hir .Jump .BREAK )
735758
736759
737760@_register (_stmt_handlers , ast .Return )
738761def _return_stmt (stmt : ast .Return , ctx : _Context ) -> None :
739- if ctx .parent_loops and ctx .parent_loops [- 1 ] is LoopKind .FOR :
762+ if ctx .parent_loops and ctx .parent_loops [- 1 ] in ( LoopKind .FOR , LoopKind . STATIC_FOR ) :
740763 raise ctx .syntax_error ("Returning from a for loop is not supported" )
741764
742765 return_val = None if stmt .value is None else _expr (stmt .value , ctx )
0 commit comments