Commit 2b3b491

Unify scalars and 0d tiles
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent c56fa24 commit 2b3b491

18 files changed

+325
-352
lines changed

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
+<!--- SPDX-License-Identifier: Apache-2.0 -->
+
+- Erased the distinction between scalars and zero-dimensional tiles.
+  They are now completely interchangeable.

docs/source/data.rst

Lines changed: 86 additions & 29 deletions
@@ -22,38 +22,112 @@ An array-based model was chosen because:
 
 Within |tile code|, only the types described in this section are supported.
 
+
 Global Arrays
 -------------
 
-.. autoclass:: Array
-   :no-members:
-   :no-index:
+A *global array* (or *array*) is a container of elements of a specific |dtype|
+arranged in a logical multidimensional space.
+
+Array's *shape* is a tuple of integer values, each denoting the length of
+the corresponding dimension.
+The length of the shape tuple equals the array's number of dimensions.
+The product of shape values equals the total logical number of elements in the array.
+
+Arrays are stored in global memory using a *strided memory layout*: in addition to a shape,
+an array also has an equally sized tuple of *strides*. Strides determine the mapping of logical
+array indices to physical memory locations. For example, for a 3-dimensional `float32` array
+with strides `(s1, s2, s3)`, the memory address of the element at the logical index
+`(i1, i2, i3)` will be:
+
+.. code-block::
+
+    base_addr + 4 * (s1 * i1 + s2 * i2 + s3 * i3),
+
+where ``base_addr`` is the base address of the array and `4` is the byte size of a single `float32`
+element.
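As a hedged illustration of the strided-layout formula in this hunk, the following plain-Python sketch computes the byte address of one element. The helper name `element_address`, the base address, and the sample strides are assumptions made for the example and are not part of cuda.tile. Strides are taken in units of elements, as in the formula above, so the summed offset is scaled by the element size.

```python
def element_address(base_addr, strides, index, itemsize):
    """Return the byte address of the element at `index` in a strided array.

    `strides` are expressed in elements (not bytes), matching the formula
    base_addr + itemsize * (s1 * i1 + s2 * i2 + ...).
    """
    if len(strides) != len(index):
        raise ValueError("strides and index must have the same length")
    # Dot product of strides and logical index gives the offset in elements.
    element_offset = sum(s * i for s, i in zip(strides, index))
    return base_addr + itemsize * element_offset


# Hypothetical example: a contiguous row-major float32 array of shape
# (2, 3, 4) has element strides (12, 4, 1).
addr = element_address(base_addr=1000, strides=(12, 4, 1),
                       index=(1, 2, 3), itemsize=4)
print(addr)  # 1000 + 4 * (12*1 + 4*2 + 1*3) = 1092
```

The same helper covers non-contiguous views: a view that keeps every other element along the last axis would simply carry a last stride of 2 instead of 1.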
+
+New arrays can only be allocated by the host, and passed to the tile kernel as arguments.
+|Tile code| can only create new views of existing arrays, for example using
+:meth:`Array.slice`. Like in Python, assigning an array object to another variable does not copy
+the underlying data, but creates another reference to the array object.
+
+Any object that implements the |DLPack| interface or the |CUDA Array Interface|
+can be passed to the kernel as an argument. Example: |CuPy| arrays and |PyTorch| tensors.
 
-.. seealso::
-   :ref:`Complete cuda.tile.Array class documentation <data-array-cuda-tile-array>`
+If two or more array arguments are passed to the kernel, their memory storage must not overlap.
+Otherwise, behavior is undefined.
+
+Array's shape can be queried using the :py:attr:`Array.shape` attribute, which
+returns a tuple of `int32` scalars. These scalars are non-constant, runtime values.
+Using `int32` makes the tile code more performant at the cost of limiting the maximum
+representable shape at 2,147,483,647 elements. This limitation will be lifted in the future.
+
+
+.. seealso::
+   :ref:`cuda.tile.Array class documentation <data-array-cuda-tile-array>`
 
 .. toctree::
    :maxdepth: 2
    :hidden:
 
    data/array
 
-Tiles
------
 
-.. autoclass:: Tile
-   :no-members:
-   :no-index:
+.. _data-tiles-and-scalars:
+
+Tiles and Scalars
+-----------------
+A *tile* is an immutable multidimensional collection of elements of a specific |dtype|.
+
+Tile's *shape* is a tuple of integer values, each denoting the length of the corresponding dimension.
+The length of the shape tuple equals the tile's number of dimensions.
+The product of shape values equals the total number of elements in the tile.
+
+The shape of a tile must be known at compile time. Each dimension of a tile must be a power of 2.
+
+Tile's dtype and shape can be queried with the ``dtype`` and ``shape`` attributes, respectively.
+For example, if ``x`` is a `float32` tile, the expression ``x.dtype`` will return
+a compile-time constant equal to :py:data:`cuda.tile.float32`.
 
-.. seealso::
-   :ref:`Complete cuda.tile.Tile class documentation <data-tile-cuda-tile-tile>`
+A zero-dimensional tile is called a *scalar*. Such tile has exactly one element. The shape
+of a scalar is the empty tuple `()`. Numeric literals like `7` or `3.14` are treated as
+constant scalars, i.e. zero-dimensional tiles.
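The shape rules stated in this hunk can be sketched in plain Python. This is only an illustration under stated assumptions: `is_power_of_two` and `tile_num_elements` are hypothetical helper names, not cuda.tile functions, and the sketch assumes that a dimension of length 1 (that is, 2 to the power 0) counts as a power of 2. It also shows why a scalar, whose shape is the empty tuple, has exactly one element: the product over an empty tuple is 1.

```python
from math import prod


def is_power_of_two(n: int) -> bool:
    """True for 1, 2, 4, 8, ...; False for zero, negatives, and other values."""
    return n > 0 and (n & (n - 1)) == 0


def tile_num_elements(shape: tuple) -> int:
    """Check the power-of-2 rule for every dimension and return the element count."""
    for dim in shape:
        if not is_power_of_two(dim):
            raise ValueError(f"dimension {dim} is not a power of 2")
    # The product of the shape values is the total number of elements.
    return prod(shape)


print(tile_num_elements((4, 8)))  # 32
print(tile_num_elements(()))      # 1: a zero-dimensional tile (scalar)
```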
+
+Since scalars are tiles, they slightly differ in behavior from Python's ``int``/``float`` objects.
+For example, they have ``dtype`` and ``shape`` attributes:
+
+.. code-block:: python
+
+    a = 0
+    # The following line will evaluate to cuda.tile.int32 in cuTile,
+    # but would raise an AttributeError in Python:
+    a.dtype
+
+Tiles can only be used in |tile code|, not host code.
+The contents of a tile do not necessarily have a physical representation in memory.
+Non-scalar tiles can be created by loading from |global arrays| using functions such as
+:py:func:`cuda.tile.load` and :py:func:`cuda.tile.gather` or with |factory| functions
+such as :py:func:`cuda.tile.zeros`.
+
+Tiles can also be stored into global arrays using functions such as :py:func:`cuda.tile.store`
+or :py:func:`cuda.tile.scatter`.
+
+Only scalars (i.e. 0-dimensional tiles) can be used as |kernel| parameters.
+
+Scalar constants are |loosely typed| by default, for example, a literal ``2`` or
+a constant attribute like ``Tile.ndim``, ``Tile.shape``, or ``Array.ndim``.
+
+.. seealso::
+   :ref:`cuda.tile.Tile class documentation <data-tile-cuda-tile-tile>`
 
 .. toctree::
    :maxdepth: 2
    :hidden:
 
    data/tile
 
+
 .. _data-element-tile-space:
 
 Element & Tile Space
@@ -145,23 +219,6 @@ to a common dtype using the following process:
 
 .. include:: generated/includes/dtype_promotion_table.rst
 
-
-.. _data-scalars:
-
-Scalars
--------
-
-A *scalar* is a single immutable value of a specific |data type|. A *scalar* and *0D-tile*
-can be used interchangably in a tile |kernel|. They can also be |kernel| parameters.
-
-Typing of a *scalar* has the following rules:
-
-- Constant scalars are |loosely typed| by default, for example, a literal ``2`` or
-  a constant property like ``Tile.ndim``, ``Tile.shape``, or ``Array.ndim``.
-- ``Array.shape`` and ``Array.stride`` are not constant by default and has default int type `int32`.
-  Using default `int32` makes kernel more performant at the cost of limiting max representable shape.
-  This limitation will be lifted in the near future.
-
 Tuples
 ------
 
docs/source/references.rst

Lines changed: 8 additions & 8 deletions
@@ -63,10 +63,10 @@
 .. |arrays| replace:: :ref:`arrays <data-array-cuda-tile-array>`
 .. |Arrays| replace:: :ref:`Arrays <data-array-cuda-tile-array>`
 
-.. |tile| replace:: :ref:`tile <data-tile-cuda-tile-tile>`
-.. |Tile| replace:: :ref:`Tile <data-tile-cuda-tile-tile>`
-.. |tiles| replace:: :ref:`tiles <data-tile-cuda-tile-tile>`
-.. |Tiles| replace:: :ref:`Tiles <data-tile-cuda-tile-tile>`
+.. |tile| replace:: :ref:`tile <data-tiles-and-scalars>`
+.. |Tile| replace:: :ref:`Tile <data-tiles-and-scalars>`
+.. |tiles| replace:: :ref:`tiles <data-tiles-and-scalars>`
+.. |Tiles| replace:: :ref:`Tiles <data-tiles-and-scalars>`
 
 .. |data type| replace:: :ref:`data type <data-data-types>`
 .. |Data type| replace:: :ref:`Data type <data-data-types>`
@@ -102,10 +102,10 @@
 .. |tile spaces| replace:: :ref:`tile spaces <data-element-tile-space>`
 .. |Tile spaces| replace:: :ref:`Tile spaces <data-element-tile-space>`
 
-.. |scalar| replace:: :ref:`scalar <data-scalars>`
-.. |Scalar| replace:: :ref:`Scalar <data-scalars>`
-.. |scalars| replace:: :ref:`scalars <data-scalars>`
-.. |Scalars| replace:: :ref:`Scalars <data-scalars>`
+.. |scalar| replace:: :ref:`scalar <data-tiles-and-scalars>`
+.. |Scalar| replace:: :ref:`Scalar <data-tiles-and-scalars>`
+.. |scalars| replace:: :ref:`scalars <data-tiles-and-scalars>`
+.. |Scalars| replace:: :ref:`Scalars <data-tiles-and-scalars>`
 
 .. |Rounding Modes| replace:: :ref:`Rounding Modes <data-rounding-modes>`
 .. |Padding Modes| replace:: :ref:`Padding Modes <data-padding-modes>`

src/cuda/tile/_datatype.py

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,7 @@
 "NumericDTypeCategories"]
 
 
-class DType(Type):
+class DType:
     """A *data type* (or *dtype*) describes the type of the objects of an |array|, |tile|, or
     operation.
 
@@ -57,6 +57,10 @@ def name(self):
     def __name__(self) -> str:
         return self._name
 
+    @function(host=True, tile=False)
+    def __repr__(self):
+        return f"<DType '{self._name}'>"
+
     @function(host=True, tile=False)
     def __str__(self):
         return self._name
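To make the effect of the added `__repr__` concrete, here is a minimal stand-alone sketch. The class below is a hypothetical stand-in that copies only the two method bodies shown in the hunk; it omits the `@function(host=True, tile=False)` decorator and everything else in the real `DType`, so it should be read as an illustration of standard Python `__repr__`/`__str__` behavior rather than of cuda.tile itself.

```python
class DType:
    """Hypothetical stand-in for illustration; not the cuda.tile class."""

    def __init__(self, name: str):
        self._name = name

    def __repr__(self) -> str:
        # Unambiguous form, used by the REPL, debuggers, and containers.
        return f"<DType '{self._name}'>"

    def __str__(self) -> str:
        # Short form, used by print() and plain f-string formatting.
        return self._name


f32 = DType("float32")
print(str(f32))   # float32
print(repr(f32))  # <DType 'float32'>
print([f32])      # [<DType 'float32'>]
```

Because containers format their items with `repr`, a dtype nested in a list or tuple prints as `<DType 'float32'>` instead of the default object representation.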

src/cuda/tile/_ir/ir.py

Lines changed: 2 additions & 0 deletions
@@ -179,6 +179,7 @@ def get_type_allow_invalid(self) -> Type:
         raise TileInternalError(f"Type of variable {self.name} not found")
 
     def set_type(self, ty: Type, force: bool = False):
+        assert isinstance(ty, Type)
         if not force:
             assert self.name not in self.ctx.typemap
         self.ctx.typemap[self.name] = ty
@@ -207,6 +208,7 @@ def get_loose_type_allow_invalid(self) -> Type:
         return self.get_type_allow_invalid() if ty is None else ty
 
     def set_loose_type(self, ty: Type, force: bool = False):
+        assert isinstance(ty, Type)
         if not force:
             assert self.name not in self.ctx._loose_typemap
         self.ctx._loose_typemap[self.name] = ty
