Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ API reference material
wave/wave


Design documentation
====================

.. toctree::
:maxdepth: 1
:caption: IR Design

ir_design

Project documentation
=====================

Expand Down
159 changes: 159 additions & 0 deletions docs/ir_design.rst
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit of a mess, but we alredy have docs/wave/ir_design_notes.rst, could we keep everything there?

Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
Vector Shapes and Hardware Constraints
======================================

This document describes the ``vector_shapes`` field on
``#wave.hardware_constraint`` and how it relates to ``mma_type``,
``elements_per_thread``, and the constraint system in the Water IR.


Overview
--------

``vector_shapes`` is an optional ``DictionaryAttr`` on
``#wave.hardware_constraint``. Each entry maps a dimension name (a string
matching a ``#wave.symbol``) to an integer specifying how many elements a
single wave processes along that dimension in one instance of an operation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it per wave or per workgroup? I keep being confused by it. Also, does this imply that the product of values in vector_shapes must be a multipe of threads_per_wave?

Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?

Copy link
Copy Markdown
Contributor Author

@tgymnich tgymnich Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, does this imply that the product of values in vector_shapes must be a multipe of threads_per_wave

Yes, modulo potential masking that may or may not happen.

Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?

No, vector_shapes can be of any size (* it makes sense for the size to be at least a multiple of threads_per_wave), since element wise operations are not really limited to fixed sizes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, masking happens for read/write and the replacement element is 0, which may be problematic if the value is then used as RHS of a division or a modulo... Maybe let's emit a warning (beware that you may have to use emitWarning(op->getLoc()) instead of op->emitWarning due to a verification order bug) when it is not divisible.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a live discussion, this appears to be an additional tiling level, so vector shape is per operation instance inside wave, not just one wave.

before expansion has replicated it.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
before expansion has replicated it.
after the expansion process has replicated it.


.. code-block:: mlir

#wave.hardware_constraint<
threads_per_wave = 64,
waves_per_block = [2, 2, 1],
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
vector_shapes = {M = 16, N = 16, K = 16},
max_bits_per_load = 128>

``vector_shapes`` is the central piece of information the compiler uses to:

* distribute work across threads within a wave,
* determine how many elements each thread processes (``elements_per_thread``),
* compute memory access strides, and
* drive the expansion (unrolling) pass that replicates operations until the
workgroup tile is covered.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

workgroup or wave?

Copy link
Copy Markdown
Contributor Author

@tgymnich tgymnich Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wave. And workgroup tile should rather be block.
waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wave. And workgroup tile should rather be block.

Workgroup is an alias for block. We should try to use AMD terminology consistently.

waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint

Then we should have a check for them matching or, if this doesn't break any existing functionality, that only one of the two mechanisms is present.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
workgroup tile is covered.
wave tile is covered.



Where vector_shapes comes from
-------------------------------

There are two cases, depending on whether ``mma_type`` is present.

**When mma_type is set,** ``vector_shapes`` is derived from the MMA
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is worth noting that indicidual mma instruction may override the mma kind provided in hardware constraints...

instruction geometry. ``WaveMmaKindAttr::getShape`` returns the ``(M, N, K)``
tile for the intrinsic and those sizes become the vector shape entries:

.. code-block:: text

mma_type = f32_16x16x16_f16 → getShape = (16, 16, 16)
vector_shapes = {M = 16, N = 16, K = 16}

Additional entries may be provided for dimensions the MMA analysis does not
cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit
``vector_shapes`` coexist.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code overwrites mma shapes by hw shapes AFAICS:

if hardware_constraint.vector_shapes:
custom.vector_shapes.update(hardware_constraint.vector_shapes)

(and the coding agent doesn't see it unless you shove its nose in it).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is being resolved in #1141


**When mma_type is absent,** ``vector_shapes`` is specified directly or derived
from workgroup / tiling constraint tile sizes. In either case it must be
present for the compiler to proceed.
Comment on lines +55 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it: if it is derived, how it "must be present"?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it must be either user provided or derivable from the constraints. PyWave has a separate vector_shapes concept per graph node.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain that in the document.


In MLIR, ``vector_shapes`` entries must all be ``IntegerAttr`` values. The
verifier in ``WaveDialect.cpp`` enforces this.


The special value 0
^^^^^^^^^^^^^^^^^^^

A vector shape of ``0`` marks a dimension as *scalar* — the wave does not tile
along it. This is used for dimensions like batch (``B``) that should not
contribute to the intra-wave data distribution:
Comment on lines +62 to +67
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the "scalar" part even though it seems to be present in the comments. The fact that a dimension isn't tiled doesn't make it scalar, it is just going to be replicated/unrolled by expansion. What deserves an investigation here is whether one is allowed to set 0 as vector shape for a symbol that is mapped to blocks/threads and, symmetrically, whether one is allowed to set a non-zero as vector shape for a dimension that isn't mapped.


.. code-block:: mlir

vector_shapes = {B = 0, M = 16, N = 16}


Relationship to workgroup and tiling constraints
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vector_shapes`` and constraint tile sizes serve different purposes:

* **Tile size** (from ``#wave.workgroup_constraint`` or
``#wave.tiling_constraint``) is the total amount of work assigned to one
workgroup or one iteration of a reduction loop along a dimension.
* **Vector shape** is the amount of work one wave handles in a single instance
of an operation along that dimension.
Comment on lines +79 to +83
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about wave constraints? It looks like they also specify wave tile size. My (potentially incorrect) undesrtanding is that vector_shape and wave constraints are redundant, but I am easily convinced otherwise.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that also seems redundant.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have an error when both are present on the same dimension then.


**When mma_type is present,** the vector shapes derive from the MMA geometry
and are typically smaller than the constraint tile sizes. The expansion pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implicitly means "workgroup constraint" tiles sizes. This is fully expected, I think we should verify that that wave tile sizes are systematically less than or equal to workgroup constraint tile sizes for the same symbol.

(which runs on the Python/FX side) replicates each
operation to cover the tile. For example, with ``BLOCK_M = 64``,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which of the tile sizes? :)

``waves_per_block = [2, 2, 1]``, and ``mma_type = f32_16x16x16_f16``
(vector shape 16 for M):

.. code-block:: text

expansion_count = ceil(64 / (2 × 16)) = 2

The MLIR IR only sees the already-expanded result: two ``wave.mma`` ops along M
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not exactly true since it depends on where the conversion happens.

rather than one. The ``vector_shapes`` remain on the
``#wave.hardware_constraint`` for verification and for passes that need to
reason about the per-wave tile.

**When mma_type is absent,** the MLIR verifier enforces that each
``vector_shapes`` entry **matches** the resolved tile size from the
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced that vector shapes must match workgroup tile size. Not saying I'm right, but please convince me one way or another.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the presence of WaveConstraint it must match that. Since element wise ops can operate on any vector size, we don't need unrolling like we do for mma.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so this joins the thread above: vector shapes and wave constraints are redundant. I think this has more to do with the redundancy, and less to do with the kind of ops we have. At HW level, elementwise ops should still operate on 1/2/4 elements, it's only a question of which level of the stack does the unrolling: wave, mlir or llvm.

corresponding ``#wave.workgroup_constraint`` or ``#wave.tiling_constraint`` for
that dimension. Unlike with mma_operations that have a fixed size, element wise operations
can operate on any number of elements_per_thread and thus don't need to be expanded multiple times.
Comment on lines +104 to +105
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dosen't sound right. None of the operations may actually opeate on arbirary number of elements per thread, there are instructions that usually work on 1, 2 or 4 elements per thread, sometimes depending on element size. Hence the expansion process. It may just replicate some operations more than some others, depending on the "native" size they support.

Copy link
Copy Markdown
Contributor Author

@tgymnich tgymnich Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, on the hardware level you are right. But from how PyWave lowers to mlir, we can just have an operation like %2 = math.exp %arg0 : vector<4xf32> operating on an arbitrary number of vector elements, which would not work with mma. MMA seems to be the only kind of operation like that.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop this one would. Same as above, this is more about where we choose to do the unrolling. Let's double check whether this happens in wave expansion or not. I'd expect it to happen.

A mismatch is a verification error:

.. code-block:: mlir

// ERROR: vector_shapes entry 'M' (16) does not match
// workgroup constraint tile size (32)
#wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 16}>

This means that in non-MMA programs, there is no separate expansion step:
``vector_shapes`` equals the tile size and each operation appears exactly once
per dimension.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is correct.

Copy link
Copy Markdown
Contributor Author

@tgymnich tgymnich Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might be in pywave... But this line relates to the verifier change is just made which forbids this.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't have a verifier that forbids things we can express in pywave, unless as the last resort for unsound semantics.



MMA kind and intrinsic shapes
------------------------------

``WaveMmaKindEnum`` enumerates hardware matrix multiply intrinsics. Each
variant encodes the output element type, tile shape (M×N×K), and input element
type. Examples:

.. code-block:: mlir

#wave.mma_kind<f32_16x16x16_f16> // (M=16, N=16, K=16)
#wave.mma_kind<f32_32x32x8_f16> // (M=32, N=32, K=8)
#wave.mma_kind<f32_16x16x128_f8f6f4> // (M=16, N=16, K=128)

``WaveMmaKindAttr::getShape(ctx, kind)`` returns the ``(M, N, K)`` tuple.

The ``kind`` attribute on ``wave.mma`` may differ from the ``mma_type`` on the
hardware constraint. When ``kind`` is absent, the
``PropagateDefaultsFromConstraints`` pass fills it from the hardware
constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same
function, each carries its own ``kind`` and its own effective vector shapes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Carries the vector shapes where?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra_attrs: Additional attributes to set on the fx_node after creation (e.g., index, vector_shapes). These are not passed to the dataclass constructor.

Inside of extra_attrs in CustomOp

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but this paragraph seems to be about MLIR. And it doesn't carry vector shapes AFAIK. I may end up adding it though.



Relationship to elements_per_thread
-------------------------------------

``elements_per_thread`` is an optional ``I64Attr`` on ``wave.read`` and
``wave.write``. It specifies how many contiguous elements a single thread
loads or stores in one operation instance:

.. code-block:: mlir

%0 = wave.read %mem { elements_per_thread = 8 }
: (!wave.tensor<[@M, @K] of f16, <global>>)
-> !wave.tensor<[@M, @K] of f16, <register>>

``elements_per_thread`` is related to ``vector_shapes`` conceptually: the
vector shape for a dimension gives the total elements a wave handles, and
dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for
thread count per workgroup dimension gives the per-thread count. The
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't parse this. My understanding is that EPT should be directly related: we know the number of elements per workgroup, there's a waves_per_block in hw constraints (which yet again may be redundant with both wave constraints and vector_shapes), which should tells us the number of elements per wave. We also know the number of threads in a wave, which should give us EPT. What happens when operation-specified EPT doesn't match the one that would be inferred by the process above -- I don't know.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is conceptually and directly related :) As everything in wave.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I still can't parse what is says. Consider rephrasing.

and dividing by threads_per_wave (for a reduction dimension)

Why specifically for a reduction dimension? Threads per wave is a scalar, how do we divide a dictionary per that?

accounting for thread count per workgroup dimension

Accounting how exactly? Is this different from division above?

Also, thread count per workgroup dimension subsumes thread count per wave.

``PropagateElementsPerThread`` pass can infer ``elements_per_thread`` from the
hardware constraint when it is not explicitly provided.
10 changes: 2 additions & 8 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,12 @@ def HardwareConstraintAttr : AttrDef<WaveDialect, "HardwareConstraint"> {
configuration rather than fundamental hardware constraints.
}];
let parameters = (ins "unsigned":$threads_per_wave,
OptionalArrayRefParameter<"unsigned">:$waves_per_block,
OptionalParameter<"::mlir::DenseI32ArrayAttr">:$waves_per_block,
OptionalParameter<"::wave::WaveMmaKindAttr">:$mma_type,
OptionalParameter<"::mlir::DictionaryAttr">:$vector_shapes,
DefaultValuedParameter<"unsigned", "128">:$max_bits_per_load);

let assemblyFormat = [{
`<` `threads_per_wave` `=` $threads_per_wave
(`,` `waves_per_block` `=` `[` $waves_per_block^ `]`)?
(`,` `mma_type` `=` $mma_type^)?
(`,` `vector_shapes` `=` $vector_shapes^)?
(`,` `max_bits_per_load` `=` $max_bits_per_load^)? `>`
}];
let assemblyFormat = "`<` struct(params) `>`";

let genVerifyDecl = 1;
}
Expand Down
4 changes: 2 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ class IndexExprsAnalysisInit {
symbolConstraints;

// Waves-per-block extracted from the hardware constraint or computed from
// wave constraints. Always stored here, even if copied from an attribute.
llvm::SmallVector<unsigned, 3> wavesPerBlock;
// wave constraints.
llvm::SmallVector<int32_t, 3> wavesPerBlock;
};

// Lattice for propagating index expressions across wave dialect operations.
Expand Down
2 changes: 1 addition & 1 deletion water/include/water/Dialect/Wave/IR/WaveUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ llvm::LogicalResult computeWavesPerBlockFromConstraints(
const llvm::SmallDenseMap<wave::WaveSymbolAttr, wave::WaveConstraintAttr>
&waveConstraints,
wave::WaveHyperparameterAttr hyperparams,
llvm::SmallVectorImpl<unsigned> &wavesPerBlock);
llvm::SmallVectorImpl<int32_t> &wavesPerBlock);

/// Permute the shape according to the mapping.
void permuteShape(llvm::ArrayRef<wave::WaveSymbolAttr> shape,
Expand Down
4 changes: 2 additions & 2 deletions water/include/water/c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ mlirAttributeIsAHardwareConstraintAttr(MlirAttribute attr);
/// Creates a new HardwareConstraintAttr
MLIR_CAPI_EXPORTED MlirAttribute mlirHardwareConstraintAttrGet(
MlirContext mlirCtx, unsigned threadsPerWave, size_t wavesPerBlockSize,
unsigned *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes,
int32_t *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes,
unsigned maxBitsPerLoad);

/// Returns the typeID of a HardwareConstraintAttr.
Expand All @@ -486,7 +486,7 @@ MLIR_CAPI_EXPORTED unsigned
mlirHardwareConstraintAttrGetThreadsPerWave(MlirAttribute attr);
MLIR_CAPI_EXPORTED intptr_t
mlirHardwareConstraintAttrGetNumWavesPerBlock(MlirAttribute attr);
MLIR_CAPI_EXPORTED unsigned
MLIR_CAPI_EXPORTED int32_t
mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr, intptr_t i);
MLIR_CAPI_EXPORTED MlirAttribute
mlirHardwareConstraintAttrGetMmaType(MlirAttribute attr);
Expand Down
24 changes: 15 additions & 9 deletions water/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ bool mlirAttributeIsAHardwareConstraintAttr(MlirAttribute attr) {

MlirAttribute
mlirHardwareConstraintAttrGet(MlirContext mlirCtx, unsigned threadsPerWave,
size_t wavesPerBlockSize, unsigned *wavesPerBlock,
size_t wavesPerBlockSize, int32_t *wavesPerBlock,
MlirAttribute mmaType, MlirAttribute vectorShapes,
unsigned maxBitsPerLoad) {
MLIRContext *ctx = unwrap(mlirCtx);
Expand All @@ -525,9 +525,14 @@ mlirHardwareConstraintAttrGet(MlirContext mlirCtx, unsigned threadsPerWave,
auto vectorShapesAttr =
llvm::cast_if_present<DictionaryAttr>(unwrap(vectorShapes));

DenseI32ArrayAttr wavesPerBlockAttr;
if (wavesPerBlockSize > 0)
wavesPerBlockAttr = DenseI32ArrayAttr::get(
ctx, llvm::ArrayRef(wavesPerBlock, wavesPerBlockSize));

return wrap(wave::HardwareConstraintAttr::get(
ctx, threadsPerWave, llvm::ArrayRef(wavesPerBlock, wavesPerBlockSize),
mmaTypeAttr, vectorShapesAttr, maxBitsPerLoad));
ctx, threadsPerWave, wavesPerBlockAttr, mmaTypeAttr, vectorShapesAttr,
maxBitsPerLoad));
}

MlirTypeID mlirWHardwareConstraintAttrGetTypeID() {
Expand All @@ -539,14 +544,15 @@ unsigned mlirHardwareConstraintAttrGetThreadsPerWave(MlirAttribute attr) {
.getThreadsPerWave();
}
intptr_t mlirHardwareConstraintAttrGetNumWavesPerBlock(MlirAttribute attr) {
return llvm::cast<wave::HardwareConstraintAttr>(unwrap(attr))
.getWavesPerBlock()
.size();
DenseI32ArrayAttr wpb =
llvm::cast<wave::HardwareConstraintAttr>(unwrap(attr)).getWavesPerBlock();
return wpb ? wpb.size() : 0;
}
unsigned mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr,
intptr_t i) {
int32_t mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr,
intptr_t i) {
return llvm::cast<wave::HardwareConstraintAttr>(unwrap(attr))
.getWavesPerBlock()[i];
.getWavesPerBlock()
.asArrayRef()[i];
}
MlirAttribute mlirHardwareConstraintAttrGetMmaType(MlirAttribute attr) {
return wrap(
Expand Down
7 changes: 3 additions & 4 deletions water/lib/Dialect/Wave/IR/WaveAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,12 +708,11 @@ WaveExprListAttr::verify(function_ref<InFlightDiagnostic()> emitError,

LogicalResult HardwareConstraintAttr::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned threadsPerWave,
ArrayRef<unsigned> wavesPerBlock, WaveMmaKindAttr mmaType,
DenseI32ArrayAttr wavesPerBlock, WaveMmaKindAttr mmaType,
DictionaryAttr vectorShapes, unsigned maxBitsPerLoad) {

if (!(wavesPerBlock.empty() || wavesPerBlock.size() == 3))
return emitError() << "waves_per_block (" << wavesPerBlock
<< ") should have 3 elements";
if (wavesPerBlock && wavesPerBlock.size() != 3)
return emitError() << "waves_per_block should have 3 elements";

if (vectorShapes) {
for (NamedAttribute attr : vectorShapes) {
Expand Down
45 changes: 42 additions & 3 deletions water/lib/Dialect/Wave/IR/WaveDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ verifyConstraints(ArrayAttr constraints,
// * The number of workgroups should be greater than or equal to one.
llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedWorkgroupSizes(
workgroupConstraints.size());
llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedSizes;
llvm::SmallDenseSet<wave::WaveWorkgroupDimAttr, 4> assignedDims;
llvm::SmallDenseSet<wave::WaveWorkgroupDimAttr, 4> needsPrimaryDim;
for (auto &&[symbol, constraint] : workgroupConstraints) {
Expand Down Expand Up @@ -250,6 +251,7 @@ verifyConstraints(ArrayAttr constraints,

int64_t workgroupSize = evaluated->front();
resolvedWorkgroupSizes[symbol] = workgroupSize;
resolvedSizes[symbol] = workgroupSize;

std::optional<llvm::SmallVector<int64_t>> resolvedDims =
wave::resolveSymbolNames(symbol, hyperparams);
Expand Down Expand Up @@ -310,15 +312,16 @@ verifyConstraints(ArrayAttr constraints,
<< workgroupSize << " for dimension: " << symbol;
}
resolvedWaveCounts[symbol] = numWaves;
resolvedSizes[symbol] = resolvedWaveSize;
}

// verify consistency between wave constraints and waves_per_block
// * If both wave constraints and waves_per_block are present, the computed
// number of waves per dimension should match the waves_per_block attribute.
if (hardwareConstraint && !hardwareConstraint.getWavesPerBlock().empty() &&
if (hardwareConstraint && hardwareConstraint.getWavesPerBlock() &&
!waveConstraints.empty()) {
llvm::ArrayRef<unsigned> wavesPerBlock =
hardwareConstraint.getWavesPerBlock();
llvm::ArrayRef<int32_t> wavesPerBlock =
hardwareConstraint.getWavesPerBlock().asArrayRef();
for (auto &&[symbol, waveConstraint] : waveConstraints) {
wave::WorkgroupConstraintAttr wgConstraint = workgroupConstraints[symbol];
unsigned wgDim =
Expand All @@ -335,6 +338,8 @@ verifyConstraints(ArrayAttr constraints,

// verify TilingConstraint
// * The number of tiles should be greater than or equal to one.
llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedTilingSizes(
tilingConstraints.size());
for (auto &&[symbol, constraint] : tilingConstraints) {
std::optional<llvm::SmallVector<int64_t>> evaluated =
wave::evaluateMapWithHyperparams(constraint.getTileSize().getMap(),
Expand All @@ -351,6 +356,8 @@ verifyConstraints(ArrayAttr constraints,
"failed to resolve dimesion symbol");

int64_t resolvedTileSize = evaluated->front();
resolvedTilingSizes[symbol] = resolvedTileSize;
resolvedSizes[symbol] = resolvedTileSize;
int64_t resolvedDim = resolvedDims->front();
int64_t numTiles = resolvedDim / resolvedTileSize;
if (numTiles < 1) {
Expand All @@ -359,6 +366,38 @@ verifyConstraints(ArrayAttr constraints,
}
}

// Verify consistency between constraints and vector_shapes (when mma_type
// is absent). Each vector_shapes entry must match the resolved tile size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// is absent). Each vector_shapes entry must match the resolved tile size
// is absent). Each vector_shapes entry must be less than or equal to the resolved tile size

// from the most specific constraint for that dimension: WaveConstraint >
// WorkgroupConstraint > TilingConstraint.
if (hardwareConstraint && hardwareConstraint.getVectorShapes() &&
!hardwareConstraint.getMmaType()) {
DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes();
for (NamedAttribute dimension : vectorShapes) {
llvm::StringRef symbolName = dimension.getName().getValue();
int64_t size = llvm::cast<IntegerAttr>(dimension.getValue()).getInt();

wave::WaveSymbolAttr symbol =
wave::WaveSymbolAttr::get(hyperparams.getContext(), symbolName);

auto it = resolvedSizes.find(symbol);

if (it == resolvedSizes.end()) {
// Batch dimensions may not be present in the resolved sizes map.
continue;
}

int64_t resolvedSize = it->second;

if (size > resolvedSize) {
return emitError() << "vector_shapes entry '" << symbolName << "' ("
<< size
<< ") is greater than the resolved tile size ("
<< resolvedSize << ") for dimension: " << symbol;
}
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also want a check what waves_per_block in hardware constraint == evaluated workgroup tile size / evaluated wave tile size. And it may make sense to add an error or warning if one tile size is not divisible by the other.

return llvm::success();
}

Expand Down
Loading
Loading