Skip to content

Commit 7fb3407

Browse files
committed
tiled view
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 2713100 commit 7fb3407

File tree

14 files changed

+722
-54
lines changed

14 files changed

+722
-54
lines changed

changelog.d/tiled-view.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- Add `Array.tiled_view(tile_shape, padding_mode=...)` to create a tiled view of an array
2+
with a fixed tile shape and padding mode.
3+
- The `TiledView` object exposes properties `dtype`, `tile_shape`, and `num_tiles`.
4+
It supports `load` and `store` methods for tile access.

docs/source/data.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,32 @@ A tile index ``(i, j, ...)`` with shape ``S`` refers to the elements of the arra
156156
When accessing the elements of an array using tile indices, the multidimensional memory layout of the array is used.
157157
To access the tile space with a different memory layout, use the `order` parameter of load/store operations.
158158

159+
.. _data-tiled-views:
160+
161+
Tiled Views
162+
-----------
163+
164+
A *tiled view* represents the |tile space| of a |global array|.
165+
166+
A tiled view's *num_tiles* is a tuple of integer values, each denoting the number of tiles of
167+
the corresponding dimension.
168+
The length of the *num_tiles* tuple equals the tile space's number of dimensions.
169+
The product of *num_tiles* values equals the total number of tiles in the tile space.
170+
171+
A tile in the tiled view can be loaded or stored using its corresponding tile index.
172+
173+
.. seealso::
174+
:ref:`cuda.tile.TiledView class documentation <data-tiled-view-cuda-tile-tiled-view>`
175+
176+
:meth:`Array.tiled_view`
177+
178+
.. toctree::
179+
:maxdepth: 2
180+
:hidden:
181+
182+
data/tiled_view
183+
184+
159185
Shape Broadcasting
160186
------------------
161187

docs/source/data/tiled_view.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
..
3+
.. SPDX-License-Identifier: Apache-2.0
4+
5+
.. currentmodule:: cuda.tile
6+
7+
.. _data-tiled-view-cuda-tile-tiled-view:
8+
9+
cuda.tile.TiledView
10+
===================
11+
12+
.. autoclass:: TiledView
13+
:members:
14+
:undoc-members:
15+
:special-members:
16+
:exclude-members: __annotations__, __dict__, __module__, __weakref__

docs/source/references.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@
102102
.. |tile spaces| replace:: :ref:`tile spaces <data-element-tile-space>`
103103
.. |Tile spaces| replace:: :ref:`Tile spaces <data-element-tile-space>`
104104

105+
.. |tiled view| replace:: :ref:`tiled view <data-tiled-views>`
106+
.. |Tiled view| replace:: :ref:`Tiled view <data-tiled-views>`
107+
.. |tiled views| replace:: :ref:`tiled views <data-tiled-views>`
108+
.. |Tiled views| replace:: :ref:`Tiled views <data-tiled-views>`
109+
105110
.. |scalar| replace:: :ref:`scalar <data-tiles-and-scalars>`
106111
.. |Scalar| replace:: :ref:`Scalar <data-tiles-and-scalars>`
107112
.. |scalars| replace:: :ref:`scalars <data-tiles-and-scalars>`

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ConstantAnnotation,
6060
Scalar,
6161
Tile,
62+
TiledView,
6263

6364
abs,
6465
add,
@@ -201,6 +202,7 @@
201202
"ConstantAnnotation",
202203
"Scalar",
203204
"Tile",
205+
"TiledView",
204206

205207
"abs",
206208
"add",

src/cuda/tile/_ir/ir.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,14 @@ def as_tuple(self) -> tuple[Var, ...]:
271271
return self.base_ptr, *self.shape, *self.strides
272272

273273

274+
@dataclass
275+
class TiledViewValue(AggregateValue):
276+
array: Var
277+
278+
def as_tuple(self) -> tuple["Var", ...]:
279+
return (self.array,)
280+
281+
274282
@dataclass
275283
class ListValue(AggregateValue):
276284
base_ptr: Var

src/cuda/tile/_ir/op_impl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .typing_support import datatype, get_signature
2020
from .ir import Var, TupleValue, Builder
21-
from .type import TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
21+
from .type import TiledViewTy, TupleTy, TileTy, DTypeSpec, EnumTy, StringTy, ArrayTy, SliceType, \
2222
ListTy, LooselyTypedScalar, RangeIterType, FunctionTy, ClosureTy, BoundMethodTy, \
2323
DTypeConstructor, Type
2424

@@ -358,6 +358,13 @@ def require_array_type(var: Var) -> ArrayTy:
358358
return ty
359359

360360

361+
def require_tiled_view_type(var: Var) -> TiledViewTy:
362+
ty = var.get_type()
363+
if not isinstance(ty, TiledViewTy):
364+
raise TileTypeError(f"Expected a tiled view, but given value has type {ty}")
365+
return ty
366+
367+
361368
def require_list_type(var: Var) -> ListTy:
362369
ty = var.get_type()
363370
if not isinstance(ty, ListTy):

src/cuda/tile/_ir/ops.py

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
add_operation, Builder,
2424
enter_nested_block, nested_block, PhiState, LoopVarState,
2525
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
26-
ListValue, ClosureValue, MemoryEffect, attribute, operand, BlockRestriction
26+
ListValue, TiledViewValue, ClosureValue, MemoryEffect, attribute, operand, BlockRestriction
2727
)
2828
from .type import PointerTy
2929
from . import hir
@@ -33,8 +33,8 @@
3333
require_signed_integer_0d_tile_type,
3434
require_tile_type, normalize_axis, require_dtype_spec,
3535
require_constant_bool, require_optional_constant_enum,
36-
require_constant_str, require_array_type, require_tuple_type, require_constant_slice,
37-
require_list_type, require_0d_tile_type,
36+
require_constant_str, require_array_type, require_tiled_view_type, require_tuple_type,
37+
require_constant_slice, require_list_type, require_0d_tile_type,
3838
require_index_or_index_tuple_type, require_constant_shape, require_constant_axis_order,
3939
require_constant_enum, require_optional_constant_int, require_optional_constant_bool,
4040
require_optional_constant_str, PrintfValidator, require_tile_maybe_loose_type,
@@ -43,7 +43,7 @@
4343
require_callable_type)
4444
from .ops_utils import (
4545
BINOP_REGISTRY, UNARYOP_REGISTRY,
46-
check_rd_and_ftz, PaddingMode,
46+
check_rd_and_ftz, PaddingMode, get_default_order,
4747
rounding_mode_to_bytecode, get_default_rounding_mode, get_dtype,
4848
change_dtype, memory_order_to_bytecode,
4949
memory_scope_to_bytecode, broadcast_shapes2, is_shape_broadcastable_to, BroadcastError,
@@ -57,7 +57,7 @@
5757
PartitionViewTy, TupleTy, TileTy, NoneType, BoundMethodTy, ArrayTy,
5858
ListTy, make_tile_ty, SliceType, DTypeConstructor, RangeIterType, Type,
5959
NONE, ModuleTy, TypeTy, LooselyTypedScalar, DTypeSpec, StringTy, InvalidType,
60-
array_size_type, ClosureTy, LiveCapturedScope, TokenTy,
60+
array_size_type, ClosureTy, LiveCapturedScope, TokenTy, TiledViewTy
6161
)
6262
from cuda.tile._datatype import (
6363
DType, is_integral, is_float, is_signed, is_boolean, is_restricted_float,
@@ -1552,6 +1552,7 @@ def getattr_impl(object: Var, name: Var) -> Var:
15521552
case ArrayTy(), "shape": return build_tuple(object.get_aggregate().shape)
15531553
case ArrayTy(), "strides": return build_tuple(object.get_aggregate().strides)
15541554
case ArrayTy(), "slice": return bind_method(object, ct._m_array_slice)
1555+
case ArrayTy(), "tiled_view": return bind_method(object, ct._m_array_tiled_view)
15551556

15561557
case TileTy(), "dtype": return loosely_typed_const(ty.dtype)
15571558
case TileTy(), "shape": return loosely_typed_const(ty.shape)
@@ -1564,6 +1565,15 @@ def getattr_impl(object: Var, name: Var) -> Var:
15641565
case TileTy(), "transpose": return bind_method(object, ct.transpose)
15651566
case TileTy(), "item": return bind_method(object, ct._m_tile_item)
15661567

1568+
case TiledViewTy(), "dtype": return loosely_typed_const(ty.dtype)
1569+
case TiledViewTy(), "tile_shape": return loosely_typed_const(ty.tile_shape)
1570+
case TiledViewTy(), "num_tiles":
1571+
[array] = object.get_aggregate().as_tuple()
1572+
return build_tuple(num_tiles(array, ty.tile_shape, get_default_order(ty.ndim)))
1573+
1574+
case TiledViewTy(), "load": return bind_method(object, ct._m_tiled_view_load)
1575+
case TiledViewTy(), "store": return bind_method(object, ct._m_tiled_view_store)
1576+
15671577
case ModuleTy(), _:
15681578
try:
15691579
return loosely_typed_const(getattr(ty.py_mod, attr_name))
@@ -2087,32 +2097,38 @@ def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
20872097
return res, res_token
20882098

20892099

2100+
def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Sequence[int],
2101+
order: Sequence[int], padding_mode: PaddingMode,
2102+
latency: Var, allow_tma: Var) -> Var:
2103+
array_ty = require_array_type(array)
2104+
broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
2105+
latency = require_optional_constant_int(latency)
2106+
allow_tma = require_optional_constant_bool(allow_tma)
2107+
_check_load_store_hints(latency, allow_tma)
2108+
2109+
view = make_partition_view(array, broadcasted_shape, order, padding_mode)
2110+
res_ty = make_tile_ty(array_ty.dtype, broadcasted_shape)
2111+
result, _token = add_operation(TileLoad, (res_ty, TokenTy()),
2112+
view=view, index=index_items, latency=latency,
2113+
allow_tma=allow_tma)
2114+
return reshape(result, shape)
2115+
2116+
20902117
@impl(ct.load)
20912118
def tile_load_impl(array: Var, index: Var, shape: Var, order: Var,
20922119
padding_mode: Var, latency: Var, allow_tma: Var) -> Var:
20932120
array_ty = require_array_type(array)
20942121
index_ty = require_index_or_index_tuple_type(index)
20952122
index_items = index.get_aggregate().items if isinstance(index_ty, TupleTy) else (index,)
2096-
20972123
if array_ty.ndim != len(index_items):
20982124
raise TileTypeError(f"Index size {len(index_items)}"
20992125
f" does not match the array rank {array_ty.ndim}")
21002126

21012127
shape = require_constant_shape(shape, allow_single_int=True, expected_rank=array_ty.ndim,
21022128
allow_0d_shape=True)
2103-
broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
21042129
order = require_constant_axis_order(order, array_ty.ndim)
21052130
padding_mode = require_constant_enum(padding_mode, PaddingMode)
2106-
latency = require_optional_constant_int(latency)
2107-
allow_tma = require_optional_constant_bool(allow_tma)
2108-
_check_load_store_hints(latency, allow_tma)
2109-
2110-
view = make_partition_view(array, broadcasted_shape, order, padding_mode)
2111-
res_ty = make_tile_ty(array_ty.dtype, broadcasted_shape)
2112-
result, _token = add_operation(TileLoad, (res_ty, TokenTy()),
2113-
view=view, index=index_items, latency=latency,
2114-
allow_tma=allow_tma)
2115-
return reshape(result, shape)
2131+
return _tile_load_impl_inner(array, index_items, shape, order, padding_mode, latency, allow_tma)
21162132

21172133

21182134
@dataclass(eq=False)
@@ -2150,29 +2166,34 @@ def _implicit_cast(src: Var, target_dtype: DType, error_context: str) -> Var:
21502166
return astype(src, target_dtype)
21512167

21522168

2169+
def _tile_store_impl_inner(array: Var, index_items: tuple[Var, ...], tile: Var,
2170+
order: Sequence[int], latency: Var, allow_tma: Var):
2171+
array_ty = require_array_type(array)
2172+
tile_ty = require_tile_type(tile)
2173+
broadcasted_shape = (1,) * array_ty.ndim if len(tile_ty.shape) == 0 else tile_ty.shape
2174+
latency = require_optional_constant_int(latency)
2175+
allow_tma = require_optional_constant_bool(allow_tma)
2176+
_check_load_store_hints(latency, allow_tma)
2177+
2178+
tile = reshape(tile, broadcasted_shape)
2179+
view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
2180+
[_token] = add_operation(TileStore, (TokenTy(),), view=view, index=index_items, tile=tile,
2181+
latency=latency, allow_tma=allow_tma)
2182+
2183+
21532184
@impl(ct.store)
21542185
def tile_store_impl(array: Var, index: Var, tile: Var, order: Var,
21552186
latency: Var, allow_tma: Var):
21562187
array_ty = require_array_type(array)
2157-
tile_ty = require_tile_type(tile)
21582188
index_ty = require_index_or_index_tuple_type(index)
21592189
index_items = index.get_aggregate().items if isinstance(index_ty, TupleTy) else (index,)
21602190
if array_ty.ndim != len(index_items):
21612191
raise TileTypeError(f"Index size {len(index_items)}"
21622192
f" does not match the array rank {array_ty.ndim}")
21632193

2164-
shape = tile_ty.shape
2165-
broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
2166-
order = require_constant_axis_order(order, array_ty.ndim)
2167-
latency = require_optional_constant_int(latency)
2168-
allow_tma = require_optional_constant_bool(allow_tma)
2169-
_check_load_store_hints(latency, allow_tma)
2170-
21712194
tile = _implicit_cast(tile, array_ty.dtype, "Stored tile is incompatible with array's dtype")
2172-
tile = reshape(tile, broadcasted_shape)
2173-
view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
2174-
[_token] = add_operation(TileStore, (TokenTy(),), view=view, index=index_items, tile=tile,
2175-
latency=latency, allow_tma=allow_tma)
2195+
order = require_constant_axis_order(order, array_ty.ndim)
2196+
_tile_store_impl_inner(array, index_items, tile, order, latency, allow_tma)
21762197

21772198

21782199
@dataclass(eq=False)
@@ -2646,16 +2667,22 @@ def join_tokens(tokens: Tuple[Var, ...], *, block: Block, res: Var, loc: Loc) ->
26462667

26472668
@dataclass(eq=False)
26482669
class NumTiles(Operation, opcode="num_tiles"):
2649-
axis: int = attribute()
26502670
view: Var = operand()
26512671

26522672
@override
26532673
def generate_bytecode(self, ctx: BytecodeContext):
26542674
view_ty: PartitionViewTy = self.view.get_type()
26552675
result_types = [ctx.type_table.tile(ctx.type_table.I32, ())] * len(view_ty.tile_shape)
2656-
values = bc.encode_GetIndexSpaceShapeOp(ctx.builder, result_types,
2657-
src=ctx.get_value(self.view))
2658-
return values[self.axis]
2676+
values = bc.encode_GetIndexSpaceShapeOp(ctx.builder, result_types, ctx.get_value(self.view))
2677+
return values
2678+
2679+
2680+
def num_tiles(array: Var, shape: Sequence[int], order: Sequence[int]) -> Tuple[Var, ...]:
2681+
array_ty = require_array_type(array)
2682+
broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
2683+
view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
2684+
result_tys = tuple(make_tile_ty(datatype.default_int_type, ()) for _s in broadcasted_shape)
2685+
return add_operation(NumTiles, result_tys, view=view)
26592686

26602687

26612688
@impl(ct.num_tiles)
@@ -2665,12 +2692,9 @@ def num_tiles_impl(array: Var, axis: Var, shape: Var, order: Var) -> Var:
26652692
axis = normalize_axis(axis, array_ty.ndim)
26662693
shape = require_constant_shape(shape, allow_single_int=True, expected_rank=array_ty.ndim,
26672694
allow_0d_shape=True)
2668-
broadcasted_shape = (1,) * array_ty.ndim if len(shape) == 0 else shape
26692695
order = require_constant_axis_order(order, array_ty.ndim)
2670-
2671-
view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
2672-
return add_operation(NumTiles, make_tile_ty(datatype.default_int_type, ()), view=view,
2673-
axis=axis)
2696+
space_shape = num_tiles(array, shape, order)
2697+
return space_shape[axis]
26742698

26752699

26762700
def full_const(shape: Sequence[int], fill_value: int | float, dtype: DType) -> Var:
@@ -4009,6 +4033,54 @@ def tile_item(tile: Var) -> Var:
40094033
return reshape(tile, ())
40104034

40114035

4036+
@impl(ct._m_array_tiled_view)
4037+
def array_tiled_view_impl(array: Var, tile_shape: Var, padding_mode: Var) -> Var:
4038+
array_ty = require_array_type(array)
4039+
shape_val = require_constant_shape(tile_shape, allow_single_int=True,
4040+
expected_rank=array_ty.ndim,
4041+
allow_0d_shape=True)
4042+
padding_mode_val = require_constant_enum(padding_mode, PaddingMode)
4043+
view_ty = TiledViewTy(array_ty, shape_val, padding_mode_val)
4044+
return make_aggregate(TiledViewValue(array), view_ty)
4045+
4046+
4047+
@impl(ct._m_tiled_view_load)
4048+
def tiled_view_load_impl(tiled_view: Var, index: Var, latency: Var, allow_tma: Var) -> Var:
4049+
view_ty = require_tiled_view_type(tiled_view)
4050+
index_ty = require_index_or_index_tuple_type(index)
4051+
index_items = index.get_aggregate().items if isinstance(index_ty, TupleTy) else (index,)
4052+
if view_ty.ndim != len(index_items):
4053+
raise TileTypeError(f"Index size {len(index_items)}"
4054+
f" does not match the tiled view rank {view_ty.ndim}")
4055+
4056+
[array] = tiled_view.get_aggregate().as_tuple()
4057+
order = get_default_order(view_ty.ndim)
4058+
return _tile_load_impl_inner(array, index_items, view_ty.tile_shape, order,
4059+
view_ty.padding_mode, latency, allow_tma)
4060+
4061+
4062+
@impl(ct._m_tiled_view_store)
4063+
def tiled_view_store_impl(tiled_view: Var, index: Var, tile: Var, latency: Var, allow_tma: Var):
4064+
view_ty = require_tiled_view_type(tiled_view)
4065+
index_ty = require_index_or_index_tuple_type(index)
4066+
index_items = index.get_aggregate().items if isinstance(index_ty, TupleTy) else (index,)
4067+
if view_ty.ndim != len(index_items):
4068+
raise TileTypeError(f"Index size {len(index_items)}"
4069+
f" does not match the tiled view rank {view_ty.ndim}")
4070+
4071+
tile_ty = require_tile_type(tile)
4072+
if not is_shape_broadcastable_to(tile_ty.shape, view_ty.tile_shape):
4073+
raise TileTypeError(f"Tile shape {tile_ty.shape} is not broadcastable"
4074+
f" to the tiled view's tile shape {view_ty.tile_shape}")
4075+
4076+
tile = broadcast_to(tile, view_ty.tile_shape)
4077+
tile = _implicit_cast(tile, view_ty.dtype,
4078+
"Stored tile is incompatible with tiled view's dtype")
4079+
[array] = tiled_view.get_aggregate().as_tuple()
4080+
order = get_default_order(view_ty.ndim)
4081+
_tile_store_impl_inner(array, index_items, tile, order, latency, allow_tma)
4082+
4083+
40124084
def store_var(local_idx: int, value: Var, loc: Loc | None = None):
40134085
scope = Scope.get_current()
40144086
new_var = scope.local.redefine(local_idx, loc or Builder.get_current().loc)

src/cuda/tile/_ir/ops_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,3 +339,7 @@ def broadcast_shapes2(s1: Sequence[int], s2: Sequence[int]) -> Tuple[int, ...]:
339339

340340
def is_shape_broadcastable_to(src: Sequence[int], dst: Sequence[int]) -> bool:
341341
return len(src) <= len(dst) and all(x in (y, 1) for x, y in zip(reversed(src), reversed(dst)))
342+
343+
344+
def get_default_order(rank: int) -> tuple[int, ...]:
345+
return tuple(range(rank))

0 commit comments

Comments
 (0)