Add MLIR verifier to check HardwareConstraint.vector_shapes consistency with tile sizes #1094
base: main
Changes from all commits
4b2f780
283d76b
eed86ef
7e0745a
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,159 @@ | ||||||
| Vector Shapes and Hardware Constraints | ||||||
| ====================================== | ||||||
|
|
||||||
| This document describes the ``vector_shapes`` field on | ||||||
| ``#wave.hardware_constraint`` and how it relates to ``mma_type``, | ||||||
| ``elements_per_thread``, and the constraint system in the Water IR. | ||||||
|
|
||||||
|
|
||||||
| Overview | ||||||
| -------- | ||||||
|
|
||||||
| ``vector_shapes`` is an optional ``DictionaryAttr`` on | ||||||
| ``#wave.hardware_constraint``. Each entry maps a dimension name (a string | ||||||
| matching a ``#wave.symbol``) to an integer specifying how many elements a | ||||||
| single wave processes along that dimension in one instance of an operation | ||||||
|
Contributor
Is it per wave or per workgroup? I keep being confused by it. Also, does this imply that the product of values in Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?
Contributor
Author
Yes, modulo potential masking that may or may not happen.
No,
Contributor
AFAIK, masking happens for read/write and the replacement element is 0, which may be problematic if the value is then used as RHS of a division or a modulo... Maybe let's emit a warning (beware that you may have to use
Contributor
After a live discussion, this appears to be an additional tiling level, so vector shape is per operation instance inside wave, not just one wave.
||||||
| before expansion has replicated it. | ||||||
|
Contributor
Suggested change
|
||||||
|
|
||||||
| .. code-block:: mlir | ||||||
|
|
||||||
| #wave.hardware_constraint< | ||||||
| threads_per_wave = 64, | ||||||
| waves_per_block = [2, 2, 1], | ||||||
| mma_type = #wave.mma_kind<f32_16x16x16_f16>, | ||||||
| vector_shapes = {M = 16, N = 16, K = 16}, | ||||||
| max_bits_per_load = 128> | ||||||
|
|
||||||
| ``vector_shapes`` is the central piece of information the compiler uses to: | ||||||
|
|
||||||
| * distribute work across threads within a wave, | ||||||
| * determine how many elements each thread processes (``elements_per_thread``), | ||||||
| * compute memory access strides, and | ||||||
| * drive the expansion (unrolling) pass that replicates operations until the | ||||||
| workgroup tile is covered. | ||||||
|
Contributor
workgroup or wave?
Contributor
Author
wave. And workgroup tile should rather be block.
Contributor
Workgroup is an alias for block. We should try to use AMD terminology consistently.
Then we should have a check for them matching or, if this doesn't break any existing functionality, that only one of the two mechanisms is present.
Contributor
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| Where vector_shapes comes from | ||||||
| ------------------------------- | ||||||
|
|
||||||
| There are two cases, depending on whether ``mma_type`` is present. | ||||||
|
|
||||||
| **When mma_type is set,** ``vector_shapes`` is derived from the MMA | ||||||
|
Contributor
It is worth noting that an individual mma instruction may override the mma kind provided in hardware constraints...
||||||
| instruction geometry. ``WaveMmaKindAttr::getShape`` returns the ``(M, N, K)`` | ||||||
| tile for the intrinsic and those sizes become the vector shape entries: | ||||||
|
|
||||||
| .. code-block:: text | ||||||
|
|
||||||
| mma_type = f32_16x16x16_f16 → getShape = (16, 16, 16) | ||||||
| vector_shapes = {M = 16, N = 16, K = 16} | ||||||
|
|
||||||
| Additional entries may be provided for dimensions the MMA analysis does not | ||||||
| cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit | ||||||
| ``vector_shapes`` coexist. | ||||||
|
Contributor
The code overwrites mma shapes by hw shapes AFAICS: wave/wave_lang/kernel/wave/utils/mma_utils.py, lines 141 to 142 in f591a21
(and the coding agent doesn't see it unless you shove its nose in it).
Contributor
This is being resolved in #1141
||||||
|
|
||||||
| **When mma_type is absent,** ``vector_shapes`` is specified directly or derived | ||||||
| from workgroup / tiling constraint tile sizes. In either case it must be | ||||||
| present for the compiler to proceed. | ||||||
|
Comment on lines +55 to +56
Contributor
I don't get it: if it is derived, how it "must be present"?
Contributor
Author
it must be either user provided or derivable from the constraints. PyWave has a separate
Contributor
Explain that in the document.
||||||
|
|
||||||
| In MLIR, ``vector_shapes`` entries must all be ``IntegerAttr`` values. The | ||||||
| verifier in ``WaveDialect.cpp`` enforces this. | ||||||
|
|
||||||
|
|
||||||
| The special value 0 | ||||||
| ^^^^^^^^^^^^^^^^^^^ | ||||||
|
|
||||||
| A vector shape of ``0`` marks a dimension as *scalar* — the wave does not tile | ||||||
| along it. This is used for dimensions like batch (``B``) that should not | ||||||
| contribute to the intra-wave data distribution: | ||||||
|
Comment on lines +62 to +67
Contributor
I don't understand the "scalar" part even though it seems to be present in the comments. The fact that a dimension isn't tiled doesn't make it scalar, it is just going to be replicated/unrolled by expansion. What deserves an investigation here is whether one is allowed to set
||||||
|
|
||||||
| .. code-block:: mlir | ||||||
|
|
||||||
| vector_shapes = {B = 0, M = 16, N = 16} | ||||||
|
|
||||||
|
|
||||||
| Relationship to workgroup and tiling constraints | ||||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|
|
||||||
| ``vector_shapes`` and constraint tile sizes serve different purposes: | ||||||
|
|
||||||
| * **Tile size** (from ``#wave.workgroup_constraint`` or | ||||||
| ``#wave.tiling_constraint``) is the total amount of work assigned to one | ||||||
| workgroup or one iteration of a reduction loop along a dimension. | ||||||
| * **Vector shape** is the amount of work one wave handles in a single instance | ||||||
| of an operation along that dimension. | ||||||
|
Comment on lines +79 to +83
Contributor
What about wave constraints? It looks like they also specify wave tile size. My (potentially incorrect) understanding is that
Contributor
Author
Yes, that also seems redundant.
Contributor
Let's have an error when both are present on the same dimension then.
||||||
|
|
||||||
| **When mma_type is present,** the vector shapes derive from the MMA geometry | ||||||
| and are typically smaller than the constraint tile sizes. The expansion pass | ||||||
|
Contributor
This implicitly means "workgroup constraint" tile sizes. This is fully expected; I think we should verify that wave tile sizes are systematically less than or equal to workgroup constraint tile sizes for the same symbol.
||||||
| (which runs on the Python/FX side) replicates each | ||||||
| operation to cover the tile. For example, with ``BLOCK_M = 64``, | ||||||
|
Contributor
which of the tile sizes? :)
||||||
| ``waves_per_block = [2, 2, 1]``, and ``mma_type = f32_16x16x16_f16`` | ||||||
| (vector shape 16 for M): | ||||||
|
|
||||||
| .. code-block:: text | ||||||
|
|
||||||
| expansion_count = ceil(64 / (2 × 16)) = 2 | ||||||
|
|
||||||
| The MLIR IR only sees the already-expanded result: two ``wave.mma`` ops along M | ||||||
|
Contributor
This is not exactly true since it depends on where the conversion happens.
||||||
| rather than one. The ``vector_shapes`` remain on the | ||||||
| ``#wave.hardware_constraint`` for verification and for passes that need to | ||||||
| reason about the per-wave tile. | ||||||
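The arithmetic in the ``BLOCK_M = 64`` example can be sketched in Python. This is only an illustration under stated assumptions: the function name ``expansion_count`` and its parameters are hypothetical and not taken from the Wave code base, and it assumes the tile being covered is the workgroup tile, split evenly across the waves along that dimension.

```python
def expansion_count(tile_size: int, waves_along_dim: int, vector_shape: int) -> int:
    """Hypothetical helper: how many times one op is replicated along a dimension.

    Assumes the tile is the workgroup tile and that each wave covers
    `vector_shape` elements per operation instance.
    """
    if tile_size <= 0 or waves_along_dim <= 0 or vector_shape <= 0:
        # A vector shape of 0 has a special meaning in the document and is
        # deliberately not handled by this sketch.
        raise ValueError("all arguments must be positive")
    covered_per_instance = waves_along_dim * vector_shape
    # Integer ceiling division, avoiding floating point.
    return -(-tile_size // covered_per_instance)


# BLOCK_M = 64, waves_per_block[0] = 2, vector shape for M = 16
print(expansion_count(64, 2, 16))  # 2
```

Which tile size plays the role of ``tile_size`` (workgroup or wave) is exactly what the review thread questions, so the first argument should be read as a placeholder for whichever tile the expansion actually targets.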
|
|
||||||
| **When mma_type is absent,** the MLIR verifier enforces that each | ||||||
| ``vector_shapes`` entry **matches** the resolved tile size from the | ||||||
|
Contributor
I'm not convinced that vector shapes must match workgroup tile size. Not saying I'm right, but please convince me one way or another.
Contributor
Author
In the presence of
Contributor
Okay, so this joins the thread above: vector shapes and wave constraints are redundant. I think this has more to do with the redundancy, and less to do with the kind of ops we have. At HW level, elementwise ops should still operate on 1/2/4 elements, it's only a question of which level of the stack does the unrolling: wave, mlir or llvm.
||||||
| corresponding ``#wave.workgroup_constraint`` or ``#wave.tiling_constraint`` for | ||||||
| that dimension. Unlike MMA operations, which have a fixed size, elementwise operations | ||||||
| can operate on any number of ``elements_per_thread`` and thus do not need to be expanded multiple times. | ||||||
|
Comment on lines +104 to +105
Contributor
This doesn't sound right. None of the operations may actually operate on an arbitrary number of elements per thread; there are instructions that usually work on 1, 2 or 4 elements per thread, sometimes depending on element size. Hence the expansion process. It may just replicate some operations more than some others, depending on the "native" size they support.
Contributor
Author
Well, on the hardware level you are right. But from how PyWave lowers to mlir, we can just have an operation like
Contributor
https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop this one would. Same as above, this is more about where we choose to do the unrolling. Let's double check whether this happens in wave expansion or not. I'd expect it to happen.
||||||
| A mismatch is a verification error: | ||||||
|
|
||||||
| .. code-block:: mlir | ||||||
|
|
||||||
| // ERROR: vector_shapes entry 'M' (16) does not match | ||||||
| // workgroup constraint tile size (32) | ||||||
| #wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 16}> | ||||||
|
|
||||||
| This means that in non-MMA programs, there is no separate expansion step: | ||||||
| ``vector_shapes`` equals the tile size and each operation appears exactly once | ||||||
| per dimension. | ||||||
|
Contributor
I don't think this is correct.
Contributor
Author
there might be in pywave... But this line relates to the verifier change I just made, which forbids this.
Contributor
We shouldn't have a verifier that forbids things we can express in pywave, unless as the last resort for unsound semantics.
||||||
|
|
||||||
|
|
||||||
| MMA kind and intrinsic shapes | ||||||
| ------------------------------ | ||||||
|
|
||||||
| ``WaveMmaKindEnum`` enumerates hardware matrix multiply intrinsics. Each | ||||||
| variant encodes the output element type, tile shape (M×N×K), and input element | ||||||
| type. Examples: | ||||||
|
|
||||||
| .. code-block:: mlir | ||||||
|
|
||||||
| #wave.mma_kind<f32_16x16x16_f16> // (M=16, N=16, K=16) | ||||||
| #wave.mma_kind<f32_32x32x8_f16> // (M=32, N=32, K=8) | ||||||
| #wave.mma_kind<f32_16x16x128_f8f6f4> // (M=16, N=16, K=128) | ||||||
|
|
||||||
| ``WaveMmaKindAttr::getShape(ctx, kind)`` returns the ``(M, N, K)`` tuple. | ||||||
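To make the naming scheme concrete, here is a small Python sketch that recovers ``(M, N, K)`` from a kind name. It assumes every name follows the ``<out>_<M>x<N>x<K>_<in>`` pattern seen in the three examples; ``mma_shape_from_name`` is a hypothetical helper, not part of the dialect, and the real ``WaveMmaKindAttr::getShape`` need not work by parsing strings.

```python
import re
from typing import Dict, Tuple

# Assumed pattern: <output type>_<M>x<N>x<K>_<input type>
_MMA_NAME = re.compile(r"^[a-z0-9]+_(\d+)x(\d+)x(\d+)_[a-z0-9]+$")


def mma_shape_from_name(kind: str) -> Tuple[int, int, int]:
    """Hypothetical helper: extract (M, N, K) from an mma kind name."""
    match = _MMA_NAME.match(kind)
    if match is None:
        raise ValueError(f"unrecognized mma kind name: {kind!r}")
    m, n, k = (int(group) for group in match.groups())
    return (m, n, k)


def vector_shapes_from_kind(kind: str) -> Dict[str, int]:
    """Build the M/N/K vector_shapes entries implied by an mma kind name."""
    return dict(zip(("M", "N", "K"), mma_shape_from_name(kind)))


print(mma_shape_from_name("f32_16x16x128_f8f6f4"))  # (16, 16, 128)
print(vector_shapes_from_kind("f32_16x16x16_f16"))  # {'M': 16, 'N': 16, 'K': 16}
```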
|
|
||||||
| The ``kind`` attribute on ``wave.mma`` may differ from the ``mma_type`` on the | ||||||
| hardware constraint. When ``kind`` is absent, the | ||||||
| ``PropagateDefaultsFromConstraints`` pass fills it from the hardware | ||||||
| constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same | ||||||
| function, each carries its own ``kind`` and its own effective vector shapes. | ||||||
|
Contributor
Carries the vector shapes where?
Contributor
Author
Inside of
Contributor
Yes, but this paragraph seems to be about MLIR. And it doesn't carry vector shapes AFAIK. I may end up adding it though.
||||||
|
|
||||||
|
|
||||||
| Relationship to elements_per_thread | ||||||
| ------------------------------------- | ||||||
|
|
||||||
| ``elements_per_thread`` is an optional ``I64Attr`` on ``wave.read`` and | ||||||
| ``wave.write``. It specifies how many contiguous elements a single thread | ||||||
| loads or stores in one operation instance: | ||||||
|
|
||||||
| .. code-block:: mlir | ||||||
|
|
||||||
| %0 = wave.read %mem { elements_per_thread = 8 } | ||||||
| : (!wave.tensor<[@M, @K] of f16, <global>>) | ||||||
| -> !wave.tensor<[@M, @K] of f16, <register>> | ||||||
|
|
||||||
| ``elements_per_thread`` is related to ``vector_shapes`` conceptually: the | ||||||
| vector shape for a dimension gives the total elements a wave handles, and | ||||||
| dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for | ||||||
| thread count per workgroup dimension gives the per-thread count. The | ||||||
|
Contributor
I can't parse this. My understanding is that EPT should be directly related: we know the number of elements per workgroup, there's a
Contributor
Author
it is conceptually and directly related :) As everything in wave.
Contributor
But I still can't parse what it says. Consider rephrasing.
Why specifically for a reduction dimension? Threads per wave is a scalar, how do we divide a dictionary per that?
Accounting how exactly? Is this different from division above? Also, thread count per workgroup dimension subsumes thread count per wave.
||||||
| ``PropagateElementsPerThread`` pass can infer ``elements_per_thread`` from the | ||||||
| hardware constraint when it is not explicitly provided. | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -222,6 +222,7 @@ verifyConstraints(ArrayAttr constraints, | |||||
| // * The number of workgroups should be greater than or equal to one. | ||||||
| llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedWorkgroupSizes( | ||||||
| workgroupConstraints.size()); | ||||||
| llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedSizes; | ||||||
| llvm::SmallDenseSet<wave::WaveWorkgroupDimAttr, 4> assignedDims; | ||||||
| llvm::SmallDenseSet<wave::WaveWorkgroupDimAttr, 4> needsPrimaryDim; | ||||||
| for (auto &&[symbol, constraint] : workgroupConstraints) { | ||||||
|
|
@@ -250,6 +251,7 @@ verifyConstraints(ArrayAttr constraints, | |||||
|
|
||||||
| int64_t workgroupSize = evaluated->front(); | ||||||
| resolvedWorkgroupSizes[symbol] = workgroupSize; | ||||||
| resolvedSizes[symbol] = workgroupSize; | ||||||
|
|
||||||
| std::optional<llvm::SmallVector<int64_t>> resolvedDims = | ||||||
| wave::resolveSymbolNames(symbol, hyperparams); | ||||||
|
|
@@ -310,15 +312,16 @@ verifyConstraints(ArrayAttr constraints, | |||||
| << workgroupSize << " for dimension: " << symbol; | ||||||
| } | ||||||
| resolvedWaveCounts[symbol] = numWaves; | ||||||
| resolvedSizes[symbol] = resolvedWaveSize; | ||||||
| } | ||||||
|
|
||||||
| // verify consistency between wave constraints and waves_per_block | ||||||
| // * If both wave constraints and waves_per_block are present, the computed | ||||||
| // number of waves per dimension should match the waves_per_block attribute. | ||||||
| if (hardwareConstraint && !hardwareConstraint.getWavesPerBlock().empty() && | ||||||
| if (hardwareConstraint && hardwareConstraint.getWavesPerBlock() && | ||||||
| !waveConstraints.empty()) { | ||||||
| llvm::ArrayRef<unsigned> wavesPerBlock = | ||||||
| hardwareConstraint.getWavesPerBlock(); | ||||||
| llvm::ArrayRef<int32_t> wavesPerBlock = | ||||||
| hardwareConstraint.getWavesPerBlock().asArrayRef(); | ||||||
| for (auto &&[symbol, waveConstraint] : waveConstraints) { | ||||||
| wave::WorkgroupConstraintAttr wgConstraint = workgroupConstraints[symbol]; | ||||||
| unsigned wgDim = | ||||||
|
|
@@ -335,6 +338,8 @@ verifyConstraints(ArrayAttr constraints, | |||||
|
|
||||||
| // verify TilingConstraint | ||||||
| // * The number of tiles should be greater than or equal to one. | ||||||
| llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedTilingSizes( | ||||||
| tilingConstraints.size()); | ||||||
| for (auto &&[symbol, constraint] : tilingConstraints) { | ||||||
| std::optional<llvm::SmallVector<int64_t>> evaluated = | ||||||
| wave::evaluateMapWithHyperparams(constraint.getTileSize().getMap(), | ||||||
|
|
@@ -351,6 +356,8 @@ verifyConstraints(ArrayAttr constraints, | |||||
| "failed to resolve dimesion symbol"); | ||||||
|
|
||||||
| int64_t resolvedTileSize = evaluated->front(); | ||||||
| resolvedTilingSizes[symbol] = resolvedTileSize; | ||||||
| resolvedSizes[symbol] = resolvedTileSize; | ||||||
| int64_t resolvedDim = resolvedDims->front(); | ||||||
| int64_t numTiles = resolvedDim / resolvedTileSize; | ||||||
| if (numTiles < 1) { | ||||||
|
|
@@ -359,6 +366,38 @@ verifyConstraints(ArrayAttr constraints, | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Verify consistency between constraints and vector_shapes (when mma_type | ||||||
| // is absent). Each vector_shapes entry must match the resolved tile size | ||||||
|
Contributor
Suggested change
|
||||||
| // from the most specific constraint for that dimension: WaveConstraint > | ||||||
| // WorkgroupConstraint > TilingConstraint. | ||||||
| if (hardwareConstraint && hardwareConstraint.getVectorShapes() && | ||||||
| !hardwareConstraint.getMmaType()) { | ||||||
| DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes(); | ||||||
| for (NamedAttribute dimension : vectorShapes) { | ||||||
| llvm::StringRef symbolName = dimension.getName().getValue(); | ||||||
| int64_t size = llvm::cast<IntegerAttr>(dimension.getValue()).getInt(); | ||||||
|
|
||||||
| wave::WaveSymbolAttr symbol = | ||||||
| wave::WaveSymbolAttr::get(hyperparams.getContext(), symbolName); | ||||||
|
|
||||||
| auto it = resolvedSizes.find(symbol); | ||||||
|
|
||||||
| if (it == resolvedSizes.end()) { | ||||||
| // Batch dimensions may not be present in the resolved sizes map. | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| int64_t resolvedSize = it->second; | ||||||
|
|
||||||
| if (size > resolvedSize) { | ||||||
| return emitError() << "vector_shapes entry '" << symbolName << "' (" | ||||||
| << size | ||||||
| << ") is greater than the resolved tile size (" | ||||||
| << resolvedSize << ") for dimension: " << symbol; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
Contributor
We also want a check what
||||||
| return llvm::success(); | ||||||
| } | ||||||
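For readers who want to experiment with the rule outside MLIR, the new check can be mirrored in a few lines of Python. This is a sketch, not the verifier itself: ``check_vector_shapes`` and its arguments are hypothetical names, symbols are plain strings rather than ``WaveSymbolAttr`` values, and ``resolved_sizes`` is assumed to already hold one resolved tile size per dimension.

```python
from typing import Dict, Optional


def check_vector_shapes(
    vector_shapes: Optional[Dict[str, int]],
    resolved_sizes: Dict[str, int],
    mma_type: Optional[str] = None,
) -> Optional[str]:
    """Return an error message, or None when the check passes or is skipped."""
    # As in the diff, the check only applies when vector_shapes is present
    # and mma_type is absent.
    if vector_shapes is None or mma_type is not None:
        return None
    for name, size in vector_shapes.items():
        resolved = resolved_sizes.get(name)
        if resolved is None:
            # Dimensions without a resolved size (e.g. batch) are skipped.
            continue
        if size > resolved:
            return (
                f"vector_shapes entry '{name}' ({size}) is greater than "
                f"the resolved tile size ({resolved}) for dimension: {name}"
            )
    return None


print(check_vector_shapes({"M": 64}, {"M": 32}))
print(check_vector_shapes({"B": 0, "M": 16}, {"M": 32}))  # None
```

Like the C++ in the diff, the sketch rejects only entries that exceed the resolved size; an entry smaller than the tile size passes.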
|
|
||||||
|
|
||||||
It's a bit of a mess, but we already have docs/wave/ir_design_notes.rst; could we keep everything there?