| 1 | +Vector Shapes and Hardware Constraints |
| 2 | +====================================== |
| 3 | + |
| 4 | +This document describes the ``vector_shapes`` field on |
| 5 | +``#wave.hardware_constraint`` and how it relates to ``mma_type``, |
| 6 | +``elements_per_thread``, and the constraint system in the Water IR. |
| 7 | + |
| 8 | + |
| 9 | +Overview |
| 10 | +-------- |
| 11 | + |
| 12 | +``vector_shapes`` is an optional ``DictionaryAttr`` on |
| 13 | +``#wave.hardware_constraint``. Each entry maps a dimension name (a string |
| 14 | +matching a ``#wave.symbol``) to an integer specifying how many elements a |
| 15 | +single wave processes along that dimension in one instance of an operation |
| 16 | +before expansion has replicated it. |
| 17 | + |
| 18 | +.. code-block:: mlir |
| 19 | + |
| 20 | +   #wave.hardware_constraint< |
| 21 | +     threads_per_wave = 64, |
| 22 | +     waves_per_block = [2, 2, 1], |
| 23 | +     mma_type = #wave.mma_kind<f32_16x16x16_f16>, |
| 24 | +     vector_shapes = {M = 16, N = 16, K = 16}, |
| 25 | +     max_bits_per_load = 128> |
| 26 | + |
| 27 | +``vector_shapes`` is the central piece of information the compiler uses to: |
| 28 | + |
| 29 | +* distribute work across threads within a wave, |
| 30 | +* determine how many elements each thread processes (``elements_per_thread``), |
| 31 | +* compute memory access strides, and |
| 32 | +* drive the expansion (unrolling) pass that replicates operations until the |
| 33 | +  workgroup tile is covered. |
| 34 | + |
| 35 | + |
| 36 | +Where vector_shapes comes from |
| 37 | +------------------------------- |
| 38 | + |
| 39 | +There are two cases, depending on whether ``mma_type`` is present. |
| 40 | + |
| 41 | +**When mma_type is set,** ``vector_shapes`` is derived from the MMA |
| 42 | +instruction geometry. ``WaveMmaKindAttr::getShape`` returns the ``(M, N, K)`` |
| 43 | +tile for the intrinsic and those sizes become the vector shape entries: |
| 44 | + |
| 45 | +.. code-block:: text |
| 46 | + |
| 47 | +   mma_type = f32_16x16x16_f16 → getShape = (16, 16, 16) |
| 48 | +   vector_shapes = {M = 16, N = 16, K = 16} |
| 49 | + |
| 50 | +Additional entries may be provided for dimensions the MMA analysis does not |
| 51 | +cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit |
| 52 | +``vector_shapes`` coexist. |
| 53 | + |
| 54 | +**When mma_type is absent,** ``vector_shapes`` is specified directly or derived |
| 55 | +from workgroup / tiling constraint tile sizes. In either case it must be |
| 56 | +present for the compiler to proceed. |
| 57 | + |
| 58 | +In MLIR, ``vector_shapes`` entries must all be ``IntegerAttr`` values. The |
| 59 | +verifier in ``WaveDialect.cpp`` enforces this. |
| 60 | + |
| 61 | + |
| 62 | +The special value 0 |
| 63 | +^^^^^^^^^^^^^^^^^^^ |
| 64 | + |
| 65 | +A vector shape of ``0`` marks a dimension as *scalar* — the wave does not tile |
| 66 | +along it. This is used for dimensions like batch (``B``) that should not |
| 67 | +contribute to the intra-wave data distribution: |
| 68 | + |
| 69 | +.. code-block:: mlir |
| 70 | + |
| 71 | +   vector_shapes = {B = 0, M = 16, N = 16} |
| 72 | + |
| 73 | + |
| 74 | +Relationship to workgroup and tiling constraints |
| 75 | +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 76 | + |
| 77 | +``vector_shapes`` and constraint tile sizes serve different purposes: |
| 78 | + |
| 79 | +* **Tile size** (from ``#wave.workgroup_constraint`` or |
| 80 | +  ``#wave.tiling_constraint``) is the total amount of work assigned to one |
| 81 | +  workgroup or one iteration of a reduction loop along a dimension. |
| 82 | +* **Vector shape** is the amount of work one wave handles in a single instance |
| 83 | +  of an operation along that dimension. |
| 84 | + |
| 85 | +**When mma_type is present,** the vector shapes derive from the MMA geometry |
| 86 | +and are typically smaller than the constraint tile sizes. The expansion pass |
| 87 | +(which runs on the Python/FX side) replicates each |
| 88 | +operation to cover the tile. For example, with ``BLOCK_M = 64``, |
| 89 | +``waves_per_block = [2, 2, 1]``, and ``mma_type = f32_16x16x16_f16`` |
| 90 | +(vector shape 16 for M): |
| 91 | + |
| 92 | +.. code-block:: text |
| 93 | + |
| 94 | +   expansion_count = ceil(64 / (2 × 16)) = 2 |
| 95 | + |
| 96 | +The MLIR IR only sees the already-expanded result: two ``wave.mma`` ops along M |
| 97 | +rather than one. The ``vector_shapes`` remain on the |
| 98 | +``#wave.hardware_constraint`` for verification and for passes that need to |
| 99 | +reason about the per-wave tile. |
| 100 | + |
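The expansion count above can be reproduced with a short Python sketch. The
helper name ``expansion_count`` and its parameters are illustrative and are not
taken from the Wave code base. The sketch only restates the formula
``ceil(tile / (waves × vector_shape))`` and does not model dimensions whose
vector shape is ``0``.

.. code-block:: python

   def expansion_count(tile_size: int, waves: int, vector_shape: int) -> int:
       """Copies of an op needed along one dimension (hypothetical helper)."""
       if waves <= 0 or vector_shape <= 0:
           # Assumption: scalar (0) dimensions are handled elsewhere.
           raise ValueError("waves and vector_shape must be positive")
       covered_per_copy = waves * vector_shape
       # Integer ceiling division avoids floating-point rounding.
       return -(-tile_size // covered_per_copy)

Under these assumptions, ``expansion_count(64, 2, 16)`` evaluates to ``2``,
which matches the two ``wave.mma`` ops along M described above.
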
| 101 | +**When mma_type is absent,** the MLIR verifier enforces that each |
| 102 | +``vector_shapes`` entry **matches** the resolved tile size from the |
| 103 | +corresponding ``#wave.workgroup_constraint`` or ``#wave.tiling_constraint`` for |
| 104 | +that dimension. Unlike MMA operations, which have a fixed size, elementwise operations |
| 105 | +can use any ``elements_per_thread`` value and thus do not need to be expanded multiple times. |
| 106 | +A mismatch is a verification error: |
| 107 | + |
| 108 | +.. code-block:: mlir |
| 109 | + |
| 110 | +   // ERROR: vector_shapes entry 'M' (16) does not match |
| 111 | +   // workgroup constraint tile size (32) |
| 112 | +   #wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 16}> |
| 113 | + |
| 114 | +This means that in non-MMA programs, there is no separate expansion step: |
| 115 | +``vector_shapes`` equals the tile size and each operation appears exactly once |
| 116 | +per dimension. |
| 117 | + |
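The non-MMA rule can be pictured as a dictionary comparison. The following
Python sketch is an illustration only: the real check lives in the MLIR
verifier, and the function name, argument layout, and error wording here are
assumptions. It also skips dimensions that have no constraint and does not
model the special value ``0``.

.. code-block:: python

   from typing import Dict, List

   def check_vector_shapes(vector_shapes: Dict[str, int],
                           tile_sizes: Dict[str, int]) -> List[str]:
       """Return one message per mismatching dimension (hypothetical helper)."""
       errors = []
       for dim, shape in vector_shapes.items():
           tile = tile_sizes.get(dim)
           if tile is None:
               # Assumption: dimensions without a constraint are not checked.
               continue
           if shape != tile:
               errors.append(
                   f"vector_shapes entry '{dim}' ({shape}) does not match "
                   f"constraint tile size ({tile})"
               )
       return errors

For instance, ``check_vector_shapes({"M": 16}, {"M": 32})`` reports one
mismatch for ``M``, while ``check_vector_shapes({"M": 32}, {"M": 32})`` returns
an empty list.
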
| 118 | + |
| 119 | +MMA kind and intrinsic shapes |
| 120 | +------------------------------ |
| 121 | + |
| 122 | +``WaveMmaKindEnum`` enumerates hardware matrix multiply intrinsics. Each |
| 123 | +variant encodes the output element type, tile shape (M×N×K), and input element |
| 124 | +type. Examples: |
| 125 | + |
| 126 | +.. code-block:: mlir |
| 127 | + |
| 128 | +   #wave.mma_kind<f32_16x16x16_f16> // (M=16, N=16, K=16) |
| 129 | +   #wave.mma_kind<f32_32x32x8_f16> // (M=32, N=32, K=8) |
| 130 | +   #wave.mma_kind<f32_16x16x128_f8f6f4> // (M=16, N=16, K=128) |
| 131 | + |
| 132 | +``WaveMmaKindAttr::getShape(ctx, kind)`` returns the ``(M, N, K)`` tuple. |
| 133 | + |
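To make the naming convention concrete, the Python sketch below reads the
``(M, N, K)`` tile out of a mnemonic such as ``f32_16x16x16_f16`` and turns it
into vector shape entries. This is not how ``WaveMmaKindAttr::getShape`` is
implemented. Both helper names are hypothetical, and the sketch assumes every
mnemonic contains an ``MxNxK`` segment delimited by underscores, as the three
examples above do.

.. code-block:: python

   import re
   from typing import Dict, Tuple

   # Matches the "MxNxK" segment of a mnemonic such as "f32_16x16x16_f16".
   _TILE = re.compile(r"_(\d+)x(\d+)x(\d+)_")

   def shape_from_mnemonic(mnemonic: str) -> Tuple[int, int, int]:
       """Extract (M, N, K) from an MMA kind mnemonic (hypothetical helper)."""
       match = _TILE.search(mnemonic)
       if match is None:
           raise ValueError(f"no MxNxK tile found in {mnemonic!r}")
       m, n, k = (int(group) for group in match.groups())
       return m, n, k

   def vector_shapes_from_mnemonic(mnemonic: str) -> Dict[str, int]:
       """Build the M/N/K vector shape entries implied by the tile."""
       m, n, k = shape_from_mnemonic(mnemonic)
       return {"M": m, "N": n, "K": k}

With these helpers, ``shape_from_mnemonic("f32_32x32x8_f16")`` yields
``(32, 32, 8)``, in line with the comment next to that kind above.
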
| 134 | +The ``kind`` attribute on ``wave.mma`` may differ from the ``mma_type`` on the |
| 135 | +hardware constraint. When ``kind`` is absent, the |
| 136 | +``PropagateDefaultsFromConstraints`` pass fills it from the hardware |
| 137 | +constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same |
| 138 | +function, each carries its own ``kind`` and its own effective vector shapes. |
| 139 | + |
| 140 | + |
| 141 | +Relationship to elements_per_thread |
| 142 | +------------------------------------- |
| 143 | + |
| 144 | +``elements_per_thread`` is an optional ``I64Attr`` on ``wave.read`` and |
| 145 | +``wave.write``. It specifies how many contiguous elements a single thread |
| 146 | +loads or stores in one operation instance: |
| 147 | + |
| 148 | +.. code-block:: mlir |
| 149 | + |
| 150 | +   %0 = wave.read %mem { elements_per_thread = 8 } |
| 151 | +     : (!wave.tensor<[@M, @K] of f16, <global>>) |
| 152 | +     -> !wave.tensor<[@M, @K] of f16, <register>> |
| 153 | + |
| 154 | +``elements_per_thread`` is conceptually related to ``vector_shapes``: the |
| 155 | +vector shape for a dimension gives the total number of elements a wave handles, and |
| 156 | +dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for |
| 157 | +the thread count along a workgroup dimension gives the per-thread count. The |
| 158 | +``PropagateElementsPerThread`` pass can infer ``elements_per_thread`` from the |
| 159 | +hardware constraint when it is not explicitly provided. |