Skip to content

Commit 94f7824

Browse files
committed
Add ct.static_iter() & tuple concatenation
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent f0be470 commit 94f7824

10 files changed

Lines changed: 450 additions & 30 deletions

File tree

changelog.d/static-iter.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added `ct.static_iter` keyword that enables compile-time `for` loops.
5+
- Operator `+` can now be used to concatenate tuples.

docs/source/operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,4 @@ Metaprogramming Support
199199

200200
static_assert
201201
static_eval
202+
static_iter

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
sqrt,
137137
static_assert,
138138
static_eval,
139+
static_iter,
139140
store,
140141
sub,
141142
sum,
@@ -276,6 +277,7 @@
276277
"sqrt",
277278
"static_assert",
278279
"static_eval",
280+
"static_iter",
279281
"store",
280282
"sub",
281283
"sum",

src/cuda/tile/_dispatch_mode.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ def __init__(self, kind: StaticEvalKind):
3737
self._kind = kind
3838

3939
def call_tile_function_from_host(self, func, args, kwargs):
40-
from cuda.tile import static_eval, static_assert
41-
if func in (static_eval, static_assert):
40+
from cuda.tile import static_eval, static_assert, static_iter
41+
if func in (static_eval, static_assert, static_iter):
4242
what = f"{func.__name__}() cannot be used"
4343
else:
44-
what = "Tile functions cannot be called"
44+
func_name = getattr(func, "__name__", "")
45+
if len(func_name) > 0:
46+
func_name = func_name + ": "
47+
what = f"{func_name}Tile functions cannot be called"
4548

4649
where = self._kind._value_
4750
raise TileStaticEvalError(f"{what} inside {where}.")

src/cuda/tile/_ir/hir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class StaticEvalKind(enum.Enum):
228228
STATIC_EVAL = "static_eval()"
229229
STATIC_ASSERT_CONDITION = "static_assert() condition"
230230
STATIC_ASSERT_MESSAGE = "static_assert() message"
231+
STATIC_ITER_ITERABLE = "static_iter() iterable"
231232

232233

233234
@dataclass
@@ -238,6 +239,7 @@ class StaticEvalExpression:
238239

239240
def if_else(cond, then_block, else_block, /): ...
240241
def loop(body, iterable, /): ... # infinite if `iterable` is None
242+
def static_foreach(body, items, /): ...
241243
def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`)
242244
def unpack(iterable, expected_len, /): ...
243245
def identity(x): ... # Identity function (i.e. returns `x`)

src/cuda/tile/_ir/ops.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from typing_extensions import override
1414

1515
import cuda.tile._stub as ct
16-
from cuda.tile import _datatype as datatype, TileValueError, TileStaticEvalError
16+
from cuda.tile import _datatype as datatype
1717
from cuda.tile import RoundingMode, MemoryOrder, MemoryScope
1818
from cuda.tile._mutex import tile_mutex
19-
from cuda.tile._exception import TileTypeError, TileSyntaxError, TileError, TileStaticAssertionError
19+
from cuda.tile._exception import TileTypeError, TileSyntaxError, TileError, \
20+
TileStaticAssertionError, TileStaticEvalError, TileValueError
2021
from cuda.tile._ir.ir import (
2122
Operation, Var, Loc, Block,
2223
add_operation, Builder,
@@ -1008,7 +1009,6 @@ def binary_arithmetic(fn: str, x: Var, y: Var, rounding_mode: Optional[RoundingM
10081009
@impl(ct.floordiv, fixed_args=["floordiv"])
10091010
@impl(ct.cdiv, fixed_args=["cdiv"])
10101011
@impl(ct.pow, fixed_args=["pow"])
1011-
@impl(operator.add, fixed_args=["add"])
10121012
@impl(operator.sub, fixed_args=["sub"])
10131013
@impl(operator.mul, fixed_args=["mul"])
10141014
@impl(operator.floordiv, fixed_args=["floordiv"])
@@ -1020,6 +1020,15 @@ def binary_arithmetic_impl(fn: str, x: Var, y: Var) -> Var:
10201020
return binary_arithmetic(fn, x, y)
10211021

10221022

1023+
@impl(operator.add)
def add_impl(x: Var, y: Var) -> Var:
    """Implement the ``+`` operator.

    When both operands have a tuple type, the result is a new tuple holding
    the items of ``x`` followed by the items of ``y``. Every other operand
    combination is forwarded to the generic ``"add"`` binary arithmetic.
    """
    # all() over a generator short-circuits, so `y` is only inspected
    # once `x` is known to be a tuple.
    both_are_tuples = all(isinstance(operand.get_type(), TupleTy) for operand in (x, y))
    if not both_are_tuples:
        return binary_arithmetic("add", x, y)

    left_items = x.get_aggregate().items
    right_items = y.get_aggregate().items
    return build_tuple(left_items + right_items)
1030+
1031+
10231032
@impl(ct.minimum, fixed_args=["min"])
10241033
@impl(ct.maximum, fixed_args=["max"])
10251034
def binary_arithmetic_impl_with_ftz(fn: str, x: Var, y: Var, flush_to_zero: Var) -> Var:
@@ -3901,6 +3910,12 @@ def static_assert_impl(condition: Var, message: Var):
39013910
" e.g. cuda.tile.static_assert() or ct.static_assert().")
39023911

39033912

3913+
@impl(ct.static_iter)
def static_iter_impl(iterable: Var):
    """Implementation registered for ``ct.static_iter``; it always raises.

    NOTE(review): presumably this is only reached when the call was not
    recognized by name during AST processing (e.g. it was invoked through
    an alias) -- confirm against the AST-to-HIR pass.

    Raises:
        TileSyntaxError: Unconditionally.
    """
    message = ("static_iter() must be used directly by name,"
               " e.g. cuda.tile.static_iter() or ct.static_iter().")
    raise TileSyntaxError(message)
3917+
3918+
39043919
@impl(hir.do_static_eval)
39053920
def do_static_eval_impl(expr: hir.StaticEvalExpression,
39063921
local_var_values: tuple[Var, ...]) -> Var:
@@ -3923,10 +3938,52 @@ def do_static_eval_impl(expr: hir.StaticEvalExpression,
39233938
if result is None:
39243939
result = ""
39253940
return loosely_typed_const(str(result))
3941+
elif expr.kind == hir.StaticEvalKind.STATIC_ITER_ITERABLE:
3942+
items = _drain_static_iter_iterable(result)
3943+
return build_tuple(tuple(items))
39263944
else:
39273945
return sym2var(result)
39283946

39293947

3948+
# Upper bound on the number of items that a static_iter() iterable may yield.
_STATIC_ITER_MAX_ITERATIONS = 1000


def _drain_static_iter_iterable(iterable) -> list[Var]:
    """Materialize a ``static_iter()`` iterable into a list of ``Var`` objects.

    Args:
        iterable: The Python object obtained by evaluating the expression
            passed to ``static_iter()``.

    Returns:
        One ``Var`` per item, in iteration order, each produced by ``sym2var()``.

    Raises:
        TileTypeError: If ``iter()`` fails on ``iterable``, or if fetching
            an item raises anything other than ``StopIteration``.
        TileStaticEvalError: If ``sym2var()`` rejects an item with a
            ``TileTypeError``, or if the iterable yields more than
            ``_STATIC_ITER_MAX_ITERATIONS`` items.
    """
    def _reason(exc: Exception) -> str:
        # Only append the underlying message when there is one, so that the
        # resulting error text never ends with a dangling ": ".
        text = str(exc)
        return f": {text}" if text else ""

    try:
        it = iter(iterable)
    except Exception as e:
        # Chain explicitly so the original failure is kept as the cause.
        raise TileTypeError(f"Invalid static_iter() iterable{_reason(e)}") from e

    items = []
    # One extra pass beyond the limit: an iterable with exactly the maximum
    # number of items must still get the chance to signal StopIteration.
    # Running the range to exhaustion means the limit was exceeded.
    for i in range(_STATIC_ITER_MAX_ITERATIONS + 1):
        try:
            x = next(it)
        except StopIteration:
            break
        except Exception as e:
            raise TileTypeError(f"Error was raised while obtaining item #{i}"
                                f" from the static_iter() iterable{_reason(e)}") from e

        try:
            var = sym2var(x)
        except TileTypeError as e:
            raise TileStaticEvalError(
                f"Invalid item #{i} of static_iter() iterable: {str(e)}") from e

        items.append(var)
    else:
        raise TileStaticEvalError(f"Maximum number of iterations"
                                  f" ({_STATIC_ITER_MAX_ITERATIONS}) has been reached"
                                  f" while unpacking the static_iter() iterable")
    return items
3985+
3986+
39303987
@impl(hir.do_static_assert)
39313988
async def do_static_assert_impl(condition: Var, message_block: hir.Block) -> None:
39323989
if not condition.is_constant():
@@ -3951,6 +4008,19 @@ async def do_static_assert_impl(condition: Var, message_block: hir.Block) -> Non
39514008
raise TileStaticAssertionError(message)
39524009

39534010

4011+
@impl(hir.static_foreach)
async def static_foreach_impl(body: hir.Block, items: Var):
    """Dispatch the HIR block ``body`` once for every item of the ``items`` tuple.

    Args:
        body: The loop body; its first block parameter is bound to the
            current item before each dispatch.
        items: A variable whose aggregate value is a ``TupleValue``.
    """
    scope = Scope.get_current()

    tuple_val = items.get_aggregate()
    assert isinstance(tuple_val, TupleValue)

    # Function-scope import, as before, but performed once rather than on every
    # iteration. NOTE(review): presumably kept local to avoid a circular import
    # between the ops module and the hir2ir pass -- confirm.
    from .._passes.hir2ir import dispatch_hir_block

    for item in tuple_val.items:
        # Rebind the body's parameter to the current item, then dispatch the body.
        scope.hir2ir_varmap[body.params[0].id] = item
        await dispatch_hir_block(body)
4022+
4023+
39544024
def var2sym(var: Var) -> Any:
39554025
if var.is_constant():
39564026
return var.get_constant()

src/cuda/tile/_passes/ast2hir.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from cuda.tile._ir import hir
1818
from cuda.tile._ir.type import ClosureDefaultPlaceholder
1919
from cuda.tile._passes.ast_util import ast_get_all_local_names
20-
from cuda.tile._stub import static_eval, static_assert
20+
from cuda.tile._stub import static_eval, static_assert, static_iter
2121

2222

2323
@lru_cache
@@ -145,6 +145,7 @@ def _get_function_hir_inner(func_def: ast.FunctionDef | ast.Lambda, signature: i
145145

146146
class LoopKind(Enum):
147147
FOR = auto()
148+
STATIC_FOR = auto()
148149
WHILE = auto()
149150

150151

@@ -247,12 +248,13 @@ def decorate(f):
247248
_expr_handlers: Dict[Type[ast.AST], Callable] = {}
248249

249250

251+
_KEYWORD_LIKE_FUNCS = (static_eval, static_assert, static_iter)
252+
_KEYWORD_LIKE_FUNC_NAMES = ("static_eval", "static_assert", "static_iter")
253+
254+
250255
@_register(_expr_handlers, ast.Call)
251256
def _call_expr(call: ast.Call, ctx: _Context) -> hir.Value:
252-
kwd_func = _parse_keyword_like_func(call.func,
253-
(static_eval, static_assert),
254-
("static_eval", "static_assert"),
255-
ctx)
257+
kwd_func = _parse_keyword_like_func(call.func, ctx)
256258
if kwd_func is not None:
257259
if kwd_func == "static_eval":
258260
if len(call.args) != 1 or len(call.keywords) != 0:
@@ -272,8 +274,11 @@ def _call_expr(call: ast.Call, ctx: _Context) -> hir.Value:
272274
condition = _call_static_eval(call.args[0],
273275
hir.StaticEvalKind.STATIC_ASSERT_CONDITION, ctx)
274276
return ctx.call(hir.do_static_assert, (condition, message_block))
277+
elif kwd_func == "static_iter":
278+
raise TileSyntaxError("static_iter() is only allowed as iterable in a `for` loop,"
279+
" i.e. `for i in ct.static_iter(...)`")
275280
else:
276-
assert False
281+
raise TileSyntaxError(f"{kwd_func} is not expected here")
277282
else:
278283
callee = _expr(call.func, ctx)
279284
args = tuple(_expr(a, ctx) for a in call.args)
@@ -342,17 +347,14 @@ def _eval_ast_expr(expr: ast.expr, ctx: _Context):
342347
return eval(code, dict(ctx.frozen_globals), {})
343348

344349

345-
def _parse_keyword_like_func(expr: ast.expr,
346-
kwd_funcs: tuple[Callable, ...],
347-
kwd_func_names: tuple[str, ...],
348-
ctx: _Context) -> str | None:
350+
def _parse_keyword_like_func(expr: ast.expr, ctx: _Context) -> str | None:
349351
if isinstance(expr, ast.Name):
350352
if (expr.id not in ctx.local_names
351-
and ctx.frozen_globals.get(expr.id) in kwd_funcs):
352-
idx = kwd_funcs.index(ctx.frozen_globals.get(expr.id))
353-
return kwd_func_names[idx]
353+
and ctx.frozen_globals.get(expr.id) in _KEYWORD_LIKE_FUNCS):
354+
idx = _KEYWORD_LIKE_FUNCS.index(ctx.frozen_globals.get(expr.id))
355+
return _KEYWORD_LIKE_FUNC_NAMES[idx]
354356
elif isinstance(expr, ast.Attribute):
355-
if expr.attr in kwd_func_names and _is_cuda_tile_module(expr.value, ctx):
357+
if expr.attr in _KEYWORD_LIKE_FUNC_NAMES and _is_cuda_tile_module(expr.value, ctx):
356358
return expr.attr
357359
return None
358360

@@ -583,21 +585,40 @@ def _for_stmt(stmt: ast.For, ctx: _Context):
583585
if len(stmt.orelse) > 0:
584586
raise ctx.syntax_error("'for-else' is not supported", loc=stmt.orelse[0])
585587

586-
iterable = _expr(stmt.iter, ctx)
587-
if not isinstance(stmt.target, ast.Name):
588-
raise ctx.unsupported_syntax(stmt.target)
588+
static_iter_expr = _get_static_iter_expr(stmt.iter, ctx)
589+
if static_iter_expr is None:
590+
kind = LoopKind.FOR
591+
op = hir.loop
592+
iterable = _expr(stmt.iter, ctx)
593+
else:
594+
kind = LoopKind.STATIC_FOR
595+
op = hir.static_foreach
596+
with ctx.change_loc(static_iter_expr):
597+
iterable = _call_static_eval(static_iter_expr,
598+
hir.StaticEvalKind.STATIC_ITER_ITERABLE, ctx)
589599

590-
ctx.parent_loops.append(LoopKind.FOR)
600+
ctx.parent_loops.append(kind)
591601
induction_var = ctx.make_value()
592602
with ctx.new_block(params=(induction_var,)) as body_block:
593-
with ctx.change_loc(stmt.target):
594-
ctx.store(stmt.target.id, induction_var)
603+
_do_assign(induction_var, stmt.target, ctx)
595604
_stmt_list(stmt.body, ctx)
596-
if body_block.jump is None:
605+
if body_block.jump is None and static_iter_expr is None:
597606
ctx.set_block_jump(hir.Jump.CONTINUE)
598607
ctx.parent_loops.pop()
599608

600-
ctx.call_void(hir.loop, (body_block, iterable))
609+
ctx.call_void(op, (body_block, iterable))
610+
611+
612+
def _get_static_iter_expr(expr: ast.expr, ctx: _Context) -> ast.expr | None:
613+
if not isinstance(expr, ast.Call):
614+
return None
615+
if _parse_keyword_like_func(expr.func, ctx) != "static_iter":
616+
return None
617+
618+
if len(expr.args) != 1 or len(expr.keywords) != 0:
619+
raise ctx.syntax_error("static_iter() expects a single expression")
620+
621+
return expr.args[0]
601622

602623

603624
def _bool_expr(expr: ast.AST, ctx: _Context) -> hir.Value:
@@ -724,19 +745,21 @@ def _if_stmt(stmt: ast.If, ctx: _Context) -> None:
724745

725746
@_register(_stmt_handlers, ast.Continue)
def _continue_stmt(stmt: ast.Continue, ctx: _Context) -> None:
    """Handle a `continue` statement by making the current block jump with CONTINUE.

    A `continue` whose innermost enclosing loop is a static_iter() `for` loop
    is rejected with a syntax error.
    """
    innermost_loop = ctx.parent_loops[-1] if ctx.parent_loops else None
    if innermost_loop is LoopKind.STATIC_FOR:
        raise ctx.syntax_error("Continue in a for loop with static_iter() is not supported")
    ctx.set_block_jump(hir.Jump.CONTINUE)
728751

729752

730753
@_register(_stmt_handlers, ast.Break)
731754
def _break_stmt(stmt: ast.Break, ctx: _Context) -> None:
732-
if ctx.parent_loops and ctx.parent_loops[-1] is LoopKind.FOR:
755+
if ctx.parent_loops and ctx.parent_loops[-1] in (LoopKind.FOR, LoopKind.STATIC_FOR):
733756
raise ctx.syntax_error("Break in a for loop is not supported")
734757
ctx.set_block_jump(hir.Jump.BREAK)
735758

736759

737760
@_register(_stmt_handlers, ast.Return)
738761
def _return_stmt(stmt: ast.Return, ctx: _Context) -> None:
739-
if ctx.parent_loops and ctx.parent_loops[-1] is LoopKind.FOR:
762+
if ctx.parent_loops and ctx.parent_loops[-1] in (LoopKind.FOR, LoopKind.STATIC_FOR):
740763
raise ctx.syntax_error("Returning from a for loop is not supported")
741764

742765
return_val = None if stmt.value is None else _expr(stmt.value, ctx)

src/cuda/tile/_stub.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,6 +2083,32 @@ def static_assert(condition, message=None, /):
20832083
"""
20842084

20852085

2086+
@function
def static_iter(iterable):
    """Iterates at compile time.

    The only valid use is as the iterable of a `for` loop::

        for ... in ct.static_iter(...):
            ...

    The wrapped expression is evaluated under the same rules as
    :py:func:`static_eval`: it may reference global and local variables and use
    the full Python syntax, but it must not perform any run-time operations.

    The expression must produce a Python iterable that yields no more than a
    pre-defined number of items (currently, 1000). Before any further processing,
    the iterable is drained into a temporary list and every item is validated as
    if it were the result of a :py:func:`static_eval` expression (i.e., it must be
    a supported compile-time constant value or a proxy object for a dynamic value
    such as a tile).

    The loop body is then inlined once per item, with the induction variable(s)
    bound to that item. The `break`, `continue`, and `return` statements are not
    allowed inside a `static_iter` loop.
    """
2110+
2111+
20862112
# ==== Private stubs ====
20872113

20882114

0 commit comments

Comments
 (0)