Skip to content

Commit eed86ef

Browse files
committed
add docs
Signed-off-by: Tim Gymnich <tim@gymni.ch>
1 parent 283d76b commit eed86ef

2 files changed

Lines changed: 168 additions & 0 deletions

File tree

docs/index.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ API reference material
2727
wave/wave
2828

2929

30+
Design documentation
31+
====================
32+
33+
.. toctree::
34+
:maxdepth: 1
35+
:caption: IR Design
36+
37+
ir_design
38+
3039
Project documentation
3140
=====================
3241

docs/ir_design.rst

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
Vector Shapes and Hardware Constraints
2+
======================================
3+
4+
This document describes the ``vector_shapes`` field on
5+
``#wave.hardware_constraint`` and how it relates to ``mma_type``,
6+
``elements_per_thread``, and the constraint system in the Water IR.
7+
8+
9+
Overview
10+
--------
11+
12+
``vector_shapes`` is an optional ``DictionaryAttr`` on
13+
``#wave.hardware_constraint``. Each entry maps a dimension name (a string
14+
matching a ``#wave.symbol``) to an integer specifying how many elements a
15+
single wave processes along that dimension in one instance of an operation
16+
before expansion has replicated it.
17+
18+
.. code-block:: mlir
19+
20+
#wave.hardware_constraint<
21+
threads_per_wave = 64,
22+
waves_per_block = [2, 2, 1],
23+
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
24+
vector_shapes = {M = 16, N = 16, K = 16},
25+
max_bits_per_load = 128>
26+
27+
``vector_shapes`` is the central piece of information the compiler uses to:
28+
29+
* distribute work across threads within a wave,
30+
* determine how many elements each thread processes (``elements_per_thread``),
31+
* compute memory access strides, and
32+
* drive the expansion (unrolling) pass that replicates operations until the
33+
workgroup tile is covered.
34+
35+
36+
Where vector_shapes comes from
37+
-------------------------------
38+
39+
There are two cases, depending on whether ``mma_type`` is present.
40+
41+
**When mma_type is set,** ``vector_shapes`` is derived from the MMA
42+
instruction geometry. ``WaveMmaKindAttr::getShape`` returns the ``(M, N, K)``
43+
tile for the intrinsic and those sizes become the vector shape entries:
44+
45+
.. code-block:: text
46+
47+
mma_type = f32_16x16x16_f16 → getShape = (16, 16, 16)
48+
vector_shapes = {M = 16, N = 16, K = 16}
49+
50+
Additional entries may be provided for dimensions the MMA analysis does not
51+
cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit
52+
``vector_shapes`` coexist.
53+
54+
**When mma_type is absent,** ``vector_shapes`` is specified directly or derived
55+
from workgroup / tiling constraint tile sizes. In either case it must be
56+
present for the compiler to proceed.
57+
58+
In MLIR, ``vector_shapes`` entries must all be ``IntegerAttr`` values. The
59+
verifier in ``WaveDialect.cpp`` enforces this.
60+
61+
62+
The special value 0
63+
^^^^^^^^^^^^^^^^^^^
64+
65+
A vector shape of ``0`` marks a dimension as *scalar* — the wave does not tile
66+
along it. This is used for dimensions like batch (``B``) that should not
67+
contribute to the intra-wave data distribution:
68+
69+
.. code-block:: mlir
70+
71+
vector_shapes = {B = 0, M = 16, N = 16}
72+
73+
74+
Relationship to workgroup and tiling constraints
75+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
76+
77+
``vector_shapes`` and constraint tile sizes serve different purposes:
78+
79+
* **Tile size** (from ``#wave.workgroup_constraint`` or
80+
``#wave.tiling_constraint``) is the total amount of work assigned to one
81+
workgroup or one iteration of a reduction loop along a dimension.
82+
* **Vector shape** is the amount of work one wave handles in a single instance
83+
of an operation along that dimension.
84+
85+
**When mma_type is present,** the vector shapes derive from the MMA geometry
86+
and are typically smaller than the constraint tile sizes. The expansion pass
87+
(which runs on the Python/FX side) replicates each
88+
operation to cover the tile. For example, with ``BLOCK_M = 64``,
89+
``waves_per_block = [2, 2, 1]``, and ``mma_type = f32_16x16x16_f16``
90+
(vector shape 16 for M):
91+
92+
.. code-block:: text
93+
94+
expansion_count = ceil(64 / (2 × 16)) = 2
95+
96+
The MLIR IR only sees the already-expanded result: two ``wave.mma`` ops along M
97+
rather than one. The ``vector_shapes`` remain on the
98+
``#wave.hardware_constraint`` for verification and for passes that need to
99+
reason about the per-wave tile.
100+
101+
**When mma_type is absent,** the MLIR verifier enforces that each
102+
``vector_shapes`` entry **matches** the resolved tile size from the
103+
corresponding ``#wave.workgroup_constraint`` or ``#wave.tiling_constraint`` for
104+
that dimension. Unlike with mma_operations that have a fixed size, element wise operations
105+
can operate on any number of elements_per_thread and thus don't need to be expanded multiple times.
106+
A mismatch is a verification error:
107+
108+
.. code-block:: mlir
109+
110+
// ERROR: vector_shapes entry 'M' (16) does not match
111+
// workgroup constraint tile size (32)
112+
#wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 16}>
113+
114+
This means that in non-MMA programs, there is no separate expansion step:
115+
``vector_shapes`` equals the tile size and each operation appears exactly once
116+
per dimension.
117+
118+
119+
MMA kind and intrinsic shapes
120+
------------------------------
121+
122+
``WaveMmaKindEnum`` enumerates hardware matrix multiply intrinsics. Each
123+
variant encodes the output element type, tile shape (M×N×K), and input element
124+
type. Examples:
125+
126+
.. code-block:: mlir
127+
128+
#wave.mma_kind<f32_16x16x16_f16> // (M=16, N=16, K=16)
129+
#wave.mma_kind<f32_32x32x8_f16> // (M=32, N=32, K=8)
130+
#wave.mma_kind<f32_16x16x128_f8f6f4> // (M=16, N=16, K=128)
131+
132+
``WaveMmaKindAttr::getShape(ctx, kind)`` returns the ``(M, N, K)`` tuple.
133+
134+
The ``kind`` attribute on ``wave.mma`` may differ from the ``mma_type`` on the
135+
hardware constraint. When ``kind`` is absent, the
136+
``PropagateDefaultsFromConstraints`` pass fills it from the hardware
137+
constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same
138+
function, each carries its own ``kind`` and its own effective vector shapes.
139+
140+
141+
Relationship to elements_per_thread
142+
-------------------------------------
143+
144+
``elements_per_thread`` is an optional ``I64Attr`` on ``wave.read`` and
145+
``wave.write``. It specifies how many contiguous elements a single thread
146+
loads or stores in one operation instance:
147+
148+
.. code-block:: mlir
149+
150+
%0 = wave.read %mem { elements_per_thread = 8 }
151+
: (!wave.tensor<[@M, @K] of f16, <global>>)
152+
-> !wave.tensor<[@M, @K] of f16, <register>>
153+
154+
``elements_per_thread`` is related to ``vector_shapes`` conceptually: the
155+
vector shape for a dimension gives the total elements a wave handles, and
156+
dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for
157+
thread count per workgroup dimension gives the per-thread count. The
158+
``PropagateElementsPerThread`` pass can infer ``elements_per_thread`` from the
159+
hardware constraint when it is not explicitly provided.

0 commit comments

Comments
 (0)