Skip to content

Commit cfd9f5f

Browse files
committed
astile
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 8e640b1 commit cfd9f5f

14 files changed

Lines changed: 489 additions & 67 deletions

changelog.d/astile.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added `ct.astile` for creating a tile from a scalar (yielding a 0-d tile) or a (possibly nested) tuple of scalars whose nesting determines the tile's shape.

docs/source/operations.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ Factory
3737
:nosignatures:
3838

3939
arange
40+
astile
4041
full
4142
ones
4243
zeros
4344

45+
4446
.. _operations-shape-dtype:
4547

4648
Shape & DType

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
argmax,
7575
argmin,
7676
assert_,
77+
astile,
7778
astype,
7879
atan2,
7980
atomic_add,
@@ -235,6 +236,7 @@
235236
"argmax",
236237
"argmin",
237238
"assert_",
239+
"astile",
238240
"astype",
239241
"atan2",
240242
"atomic_add",

src/cuda/tile/_ir/op_impl.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,15 @@ class _CurrentStub(threading.local):
227227
_current_stub = _CurrentStub()
228228

229229

230-
def _is_0d_tile(ty: Type, dtype_predicate: Callable[[DType], bool] = lambda _: True) -> bool:
230+
def is_0d_tile(ty: Type, dtype_predicate: Callable[[DType], bool] = lambda _: True) -> bool:
231231
return isinstance(ty, TileTy) and ty.ndim == 0 and dtype_predicate(ty.dtype)
232232

233233

234234
def require_constant_int(var: Var) -> int:
235235
if not var.is_constant():
236236
raise _make_type_error("Expected an integer constant, but given value is not constant", var)
237237
ty = var.get_type()
238-
if not _is_0d_tile(ty, is_integral):
238+
if not is_0d_tile(ty, is_integral):
239239
raise _make_type_error(f"Expected an integer constant, but given value has type {ty}",
240240
var)
241241
return var.get_constant()
@@ -251,14 +251,14 @@ def require_constant_bool(var: Var) -> bool:
251251
if not var.is_constant():
252252
raise _make_type_error("Expected a boolean constant, but given value is not constant", var)
253253
ty = var.get_type()
254-
if not _is_0d_tile(ty, is_boolean):
254+
if not is_0d_tile(ty, is_boolean):
255255
raise _make_type_error(f"Expected a boolean constant, but given value has type {ty}", var)
256256
return var.get_constant()
257257

258258

259259
def require_constant_scalar(var: Var) -> bool | int | float:
260260
ty = var.get_type()
261-
if not _is_0d_tile(ty):
261+
if not is_0d_tile(ty):
262262
raise _make_type_error(f"Expected a scalar constant, but given value has type {ty}", var)
263263
if not var.is_constant():
264264
raise _make_type_error(f"Expected a constant, but given value has non-constant type {ty}",
@@ -274,7 +274,7 @@ def require_constant_scalar_tuple(var: Var) -> tuple[bool | int | float, ...]:
274274
tuple_val = var.get_aggregate()
275275
assert isinstance(tuple_val, TupleValue)
276276
for i, (item_ty, item) in enumerate(zip(ty.value_types, tuple_val.items, strict=True)):
277-
if not _is_0d_tile(item_ty):
277+
if not is_0d_tile(item_ty):
278278
raise _make_type_error(f"Expected a tuple of scalar constants,"
279279
f" but item at position #{i} has type {item_ty}", var)
280280
if not item.is_constant():
@@ -323,6 +323,12 @@ def require_dtype_spec(var: Var) -> DType:
323323
return ty.dtype
324324

325325

326+
def require_optional_dtype_spec(var: Var) -> DType | None:
327+
if var.is_constant() and var.get_constant() is None:
328+
return None
329+
return require_dtype_spec(var)
330+
331+
326332
def require_optional_constant_enum(var: Var, enum: EnumMeta):
327333
if var.is_constant() and var.get_constant() is None:
328334
return None
@@ -363,14 +369,14 @@ def require_constant_int_tuple(var: Var, allow_single_int: bool = False) -> Tupl
363369
" but given value is not constant", var)
364370

365371
ty = var.get_type()
366-
if allow_single_int and _is_0d_tile(ty):
372+
if allow_single_int and is_0d_tile(ty):
367373
return require_constant_int(var),
368374

369375
if not isinstance(ty, TupleTy):
370376
raise _make_type_error(f"Expected a tuple, but given value has type {ty}", var)
371377

372378
for i, item_ty in enumerate(ty.value_types):
373-
if not _is_0d_tile(item_ty, is_integral):
379+
if not is_0d_tile(item_ty, is_integral):
374380
raise _make_type_error(f"Expected a tuple of integers,"
375381
f" but element #{i} has type {item_ty}", var)
376382

@@ -472,7 +478,7 @@ def require_signed_integer_0d_tile_type(var: Var) -> TileTy:
472478

473479
def require_bool(var: Var) -> TileTy:
474480
ty = var.get_type()
475-
if not _is_0d_tile(ty, is_boolean):
481+
if not is_0d_tile(ty, is_boolean):
476482
raise _make_type_error(f"Expected a bool, but given value has type {ty}", var)
477483
return ty
478484

src/cuda/tile/_ir/ops.py

Lines changed: 103 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import builtins
55
import dataclasses
66
import enum
7+
import functools
78
import math
89
import operator
910
from contextlib import contextmanager
@@ -19,7 +20,7 @@
1920
from cuda.tile import _datatype as datatype
2021
from cuda.tile import RoundingMode, MemoryOrder, MemoryScope
2122
from cuda.tile._mutex import tile_mutex
22-
from cuda.tile._exception import TileTypeError, TileSyntaxError, TileError, \
23+
from cuda.tile._exception import TileInternalError, TileTypeError, TileSyntaxError, TileError, \
2324
TileStaticAssertionError, TileStaticEvalError, TileValueError, TileUnsupportedFeatureError
2425
from cuda.tile._ir.ir import (
2526
Operation, Var, Loc, Block, add_operation, Builder, enter_nested_block, nested_block,
@@ -34,7 +35,8 @@
3435
from . import hir, hir_stubs
3536
from .hir import ResolvedName
3637
from .op_impl import (
37-
ImplRegistry, require_constant_int, require_constant_int_tuple,
38+
ImplRegistry, is_0d_tile, require_constant_int, require_constant_int_tuple,
39+
require_optional_dtype_spec,
3840
require_signed_integer_0d_tile_type,
3941
require_tile_type, normalize_axis, require_dtype_spec,
4042
require_constant_bool, require_optional_constant_enum,
@@ -3283,14 +3285,14 @@ def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
32833285
return space_shape[axis]
32843286

32853287

3286-
def full_const(shape: Sequence[int], fill_value: int | float, dtype: DType) -> Var:
3288+
def _const(shape: Sequence[int], value: int | float | bool | tuple, dtype: DType) -> Var:
32873289
res_ty = make_tile_ty(dtype, shape)
3288-
return strictly_typed_const(fill_value, res_ty)
3290+
return strictly_typed_const(value, res_ty)
32893291

32903292

32913293
def full(shape: Sequence[int], fill_value: Var, dtype: DType) -> Var:
32923294
if fill_value.is_constant():
3293-
return full_const(shape, fill_value.get_constant(), dtype)
3295+
return _const(shape, fill_value.get_constant(), dtype)
32943296
fill_value = astype(fill_value, dtype)
32953297
return broadcast_to(fill_value, shape)
32963298

@@ -3307,14 +3309,93 @@ def full_impl(shape: Var, fill_value: Var, dtype: Var) -> Var:
33073309
def ones_impl(shape: Var, dtype: Var) -> Var:
33083310
shape = require_constant_shape(shape, allow_single_int=True)
33093311
dtype = require_dtype_spec(dtype)
3310-
return full_const(shape, 1, dtype)
3312+
return _const(shape, 1, dtype)
33113313

33123314

33133315
@impl(ct.zeros)
33143316
def zeros_impl(shape: Var, dtype: Var) -> Var:
33153317
shape = require_constant_shape(shape, allow_single_int=True)
33163318
dtype = require_dtype_spec(dtype)
3317-
return full_const(shape, 0, dtype)
3319+
return _const(shape, 0, dtype)
3320+
3321+
3322+
def _path_str(path: tuple[int, ...]) -> str:
3323+
return "value" + "".join(f"[{i}]" for i in path)
3324+
3325+
3326+
def _tuple_shape(ty: Type, path: tuple[int, ...]) -> tuple[int, ...]:
3327+
path_str = _path_str(path)
3328+
if not isinstance(ty, TupleTy):
3329+
if not isinstance(ty, TileTy):
3330+
raise TileTypeError(
3331+
f"Expected scalar elements at {path_str}; "
3332+
f"got element of type {ty}")
3333+
3334+
if ty.ndim != 0:
3335+
raise TileTypeError(
3336+
f"Expected scalar elements at {path_str}; "
3337+
f"got a tile of shape {ty.shape}")
3338+
3339+
assert is_0d_tile(ty)
3340+
return ()
3341+
3342+
n = len(ty)
3343+
if not _is_power_of_2(n):
3344+
raise TileTypeError(f"Tuple length {n} at {path_str} is not a power of 2")
3345+
3346+
inner_shapes = {_tuple_shape(t, path + (i,)) for i, t in enumerate(ty.value_types)}
3347+
if len(inner_shapes) != 1:
3348+
raise TileTypeError(f"Tuple has non-uniform inner shapes at {path_str}")
3349+
3350+
return (n,) + inner_shapes.pop()
3351+
3352+
3353+
def _flatten_tuple(value: Var) -> tuple[Var, ...]:
3354+
value_ty = value.get_type()
3355+
if not isinstance(value_ty, TupleTy):
3356+
return (value,)
3357+
return sum((_flatten_tuple(i) for i in value.get_aggregate().items), start=())
3358+
3359+
3360+
def _cat_tuple(tiles: tuple[Var, ...]) -> Var:
3361+
if len(tiles) == 0:
3362+
raise TileInternalError("Expected non-empty tile tuple")
3363+
3364+
if len(tiles) == 1:
3365+
require_0d_tile_type(tiles[0])
3366+
return reshape(tiles[0], (1,))
3367+
3368+
assert len(tiles) % 2 == 0
3369+
mid = len(tiles) // 2
3370+
left = _cat_tuple(tiles[:mid])
3371+
right = _cat_tuple(tiles[mid:])
3372+
return cat((left, right), axis=0)
3373+
3374+
3375+
@impl(ct.astile)
3376+
def astile_impl(value: Var, dtype: Var) -> Var:
3377+
dtype: Optional[DType] = require_optional_dtype_spec(dtype)
3378+
value_ty = value.get_type()
3379+
if is_0d_tile(value_ty):
3380+
return value if dtype is None else astype(value, dtype)
3381+
3382+
if not isinstance(value_ty, TupleTy):
3383+
raise TileTypeError(
3384+
f"Expected a scalar or (possibly nested) tuple of scalars; "
3385+
f"got value of type {value_ty}")
3386+
3387+
shape = _tuple_shape(value_ty, path=())
3388+
tiles = _flatten_tuple(value)
3389+
dtype = (functools.reduce(promote_dtypes, (require_0d_tile_type(t).dtype for t in tiles))
3390+
if dtype is None
3391+
else dtype)
3392+
3393+
if value.is_constant():
3394+
return _const(shape, value.get_constant(), dtype)
3395+
3396+
tiles = tuple(astype(t, dtype) for t in tiles)
3397+
flat = _cat_tuple(tiles)
3398+
return reshape(flat, shape)
33183399

33193400

33203401
_TileShape = Tuple[int, ...]
@@ -4116,29 +4197,24 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
41164197
return bc.encode_CatOp(ctx.builder, return_type_id, x_value, y_value, self.axis)
41174198

41184199

4119-
def cat(tiles: Var, axis: int) -> Var:
4120-
tuple_ty = require_tuple_type(tiles)
4121-
items = tiles.get_aggregate().items
4122-
if len(items) == 1:
4123-
return items[0]
4124-
4125-
if len(tuple_ty) == 0:
4200+
def cat(tiles: tuple[Var, ...], axis: int) -> Var:
4201+
if len(tiles) == 0:
41264202
raise TileTypeError("cat() received an empty tuple")
4127-
elif len(items) == 1:
4128-
return items[0]
4129-
elif len(tuple_ty) > 2:
4130-
raise TileTypeError(f"cat() supports at most 2 tiles, got {len(tuple_ty)}")
4203+
if len(tiles) == 1:
4204+
return tiles[0]
4205+
if len(tiles) > 2:
4206+
raise TileTypeError(f"cat() supports at most 2 tiles, got {len(tiles)}")
41314207

4132-
x_tile, y_tile = items
4208+
x_tile, y_tile = tiles
41334209

4134-
if not isinstance(first_tile := tuple_ty.value_types[0], TileTy):
4135-
raise TileTypeError(f"Expected tuple of Tile, got a {first_tile}")
4210+
if not isinstance(first_tile_ty := tiles[0].get_type(), TileTy):
4211+
raise TileTypeError(f"Expected tuple of Tile, got a {first_tile_ty}")
41364212

4137-
dtype = first_tile.dtype
4138-
rank = first_tile.ndim
4139-
shape_value = list(first_tile.shape)
4213+
dtype = first_tile_ty.dtype
4214+
rank = first_tile_ty.ndim
4215+
shape_value = list(first_tile_ty.shape)
41404216
axis = normalize_axis(axis, rank)
4141-
for tile_ty in tuple_ty.value_types[1:]:
4217+
for tile_ty in (t.get_type() for t in tiles[1:]):
41424218
if not isinstance(tile_ty, TileTy):
41434219
raise TileTypeError(f"Expected tuple of Tile, got a {tile_ty}")
41444220
if tile_ty.ndim != rank:
@@ -4167,8 +4243,9 @@ def _is_power_of_2(x: int):
41674243

41684244
@impl(ct.cat)
41694245
def cat_impl(tiles: Var, axis: Var) -> Var:
4246+
require_tuple_type(tiles)
41704247
const_axis = require_constant_int(axis)
4171-
return cat(tiles, const_axis)
4248+
return cat(tiles.get_aggregate().items, const_axis)
41724249

41734250

41744251
# Does not support broadcasting or type promotion

src/cuda/tile/_ir/typing_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_constant_value(val: Any) -> Any:
190190
return val
191191
if is_dtype(val):
192192
return to_dtype(val)
193-
if isinstance(val, tuple) and not any(isinstance(x, tuple) for x in val):
193+
if isinstance(val, tuple):
194194
return tuple(get_constant_value(x) for x in val)
195195
typ = type(val)
196196
prefix = "" if typ.__module__ == "builtins" else f"{typ.__module__}."

src/cuda/tile/_ir2bytecode.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from cuda.tile._ir.type import (
2626
PartitionViewTy, StridedViewTy, GatherScatterViewTy, Type, TileTy, PointerTy, TokenTy,
27-
TupleTy, ArrayTy, size_to_bytecode,
27+
ArrayTy, size_to_bytecode,
2828
)
2929

3030

@@ -103,6 +103,31 @@ def _constant_to_bytes(value: int | float, dtype: DType) -> bytes:
103103
raise TypeError(f"Cannot make a constant out of {dtype}")
104104

105105

106+
# Encode a potentially nested constant tuple as the raw row-major byte buffer according to MLIR's
107+
# "DenseElementsAttr" non-splat format.
108+
def _constant_tuple_to_bytes(value, dtype: DType, shape: tuple[int, ...]) -> bytes:
109+
# Note that MLIR requires bit packing for non-splat DenseElementsAttr<i1>
110+
if dtype == datatype.bool_ and isinstance(value, tuple):
111+
flat_bools = _flatten_bools(value)
112+
bits = 0
113+
for i, v in enumerate(flat_bools):
114+
if v:
115+
bits |= 1 << i
116+
return bits.to_bytes((len(flat_bools) + 7) // 8, "little")
117+
118+
if len(shape) == 0:
119+
return _constant_to_bytes(value, dtype)
120+
121+
assert len(value) == shape[0]
122+
return b"".join(_constant_tuple_to_bytes(c, dtype, shape[1:]) for c in value)
123+
124+
125+
def _flatten_bools(value) -> tuple[bool]:
126+
if not isinstance(value, tuple):
127+
return (bool(value),)
128+
return sum((_flatten_bools(v) for v in value), start=())
129+
130+
106131
def _get_type_conversion_encoder(from_dtype: Type, to_dtype: Type):
107132

108133
def kind(t):
@@ -401,20 +426,23 @@ def bitcast(self, value: bc.Value, fromty: Type, toty: Type) -> bc.Value:
401426
value = bc.encode_BitcastOp(self.builder, typeid(self.type_table, toty), value)
402427
return value
403428

404-
def constant(self, value: int | float, ty: Type) -> bc.Value:
405-
if isinstance(ty, TileTy):
406-
dtype = ty.dtype
407-
else:
408-
raise TypeError(f"Cannot make a constant tuple out of {ty}")
429+
def constant(self, value, ty: Type) -> bc.Value:
430+
if not isinstance(ty, TileTy):
431+
raise TypeError(f"Cannot encode a constant of type {ty}; expected a TileTy")
409432

410-
data = _constant_to_bytes(value, dtype)
411-
return bc.encode_ConstantOp(self.builder, typeid(self.type_table, ty), data)
433+
def get_numel(v):
434+
if not isinstance(v, tuple):
435+
return 1
436+
return sum(get_numel(i) for i in v)
412437

413-
def constant_tuple(self, value, ty: Type) -> Tuple[bc.Value, ...]:
414-
if isinstance(ty, TupleTy):
415-
return sum((self.constant_tuple(item_val, item_ty)
416-
for item_ty, item_val in zip(ty.value_types, value, strict=True)), ())
417-
return self.constant(value, ty),
438+
if get_numel(value) == 1:
439+
while isinstance(value, tuple):
440+
value = value[0]
441+
data = _constant_to_bytes(value, ty.dtype)
442+
else:
443+
assert isinstance(value, tuple)
444+
data = _constant_tuple_to_bytes(value, ty.dtype, ty.shape)
445+
return bc.encode_ConstantOp(self.builder, typeid(self.type_table, ty), data)
418446

419447
def index_tuple(self,
420448
index: tuple[Var, ...], *, keep_i64: bool = False) -> Tuple[bc.Value, ...]:

0 commit comments

Comments
 (0)