diff --git a/docs/index.rst b/docs/index.rst index 243e420e4f..a1b03239c2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -27,6 +27,15 @@ API reference material wave/wave +Design documentation +==================== + +.. toctree:: + :maxdepth: 1 + :caption: IR Design + + ir_design + Project documentation ===================== diff --git a/docs/ir_design.rst b/docs/ir_design.rst new file mode 100644 index 0000000000..d23d195e54 --- /dev/null +++ b/docs/ir_design.rst @@ -0,0 +1,159 @@ +Vector Shapes and Hardware Constraints +====================================== + +This document describes the ``vector_shapes`` field on +``#wave.hardware_constraint`` and how it relates to ``mma_type``, +``elements_per_thread``, and the constraint system in the Water IR. + + +Overview +-------- + +``vector_shapes`` is an optional ``DictionaryAttr`` on +``#wave.hardware_constraint``. Each entry maps a dimension name (a string +matching a ``#wave.symbol``) to an integer specifying how many elements a +single wave processes along that dimension in one instance of an operation +before expansion has replicated it. + +.. code-block:: mlir + + #wave.hardware_constraint< + threads_per_wave = 64, + waves_per_block = [2, 2, 1], + mma_type = #wave.mma_kind, + vector_shapes = {M = 16, N = 16, K = 16}, + max_bits_per_load = 128> + +``vector_shapes`` is the central piece of information the compiler uses to: + +* distribute work across threads within a wave, +* determine how many elements each thread processes (``elements_per_thread``), +* compute memory access strides, and +* drive the expansion (unrolling) pass that replicates operations until the + workgroup tile is covered. + + +Where vector_shapes comes from +------------------------------- + +There are two cases, depending on whether ``mma_type`` is present. + +**When mma_type is set,** ``vector_shapes`` is derived from the MMA +instruction geometry. ``WaveMmaKindAttr::getShape`` returns the ``(M, N, K)`` +tile for the intrinsic and those sizes become the vector shape entries: + +.. code-block:: text + + mma_type = f32_16x16x16_f16 → getShape = (16, 16, 16) + vector_shapes = {M = 16, N = 16, K = 16} + +Additional entries may be provided for dimensions the MMA analysis does not +cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit +``vector_shapes`` coexist. + +**When mma_type is absent,** ``vector_shapes`` is specified directly or derived +from workgroup / tiling constraint tile sizes. In either case it must be +present for the compiler to proceed. + +In MLIR, ``vector_shapes`` entries must all be ``IntegerAttr`` values. The +verifier in ``WaveDialect.cpp`` enforces this. + + +The special value 0 +^^^^^^^^^^^^^^^^^^^ + +A vector shape of ``0`` marks a dimension as *scalar* — the wave does not tile +along it. This is used for dimensions like batch (``B``) that should not +contribute to the intra-wave data distribution: + +.. code-block:: mlir + + vector_shapes = {B = 0, M = 16, N = 16} + + +Relationship to workgroup and tiling constraints +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``vector_shapes`` and constraint tile sizes serve different purposes: + +* **Tile size** (from ``#wave.workgroup_constraint`` or + ``#wave.tiling_constraint``) is the total amount of work assigned to one + workgroup or one iteration of a reduction loop along a dimension. +* **Vector shape** is the amount of work one wave handles in a single instance + of an operation along that dimension. + +**When mma_type is present,** the vector shapes derive from the MMA geometry +and are typically smaller than the constraint tile sizes. The expansion pass +(which runs on the Python/FX side) replicates each +operation to cover the tile. For example, with ``BLOCK_M = 64``, +``waves_per_block = [2, 2, 1]``, and ``mma_type = f32_16x16x16_f16`` +(vector shape 16 for M): + +.. code-block:: text + + expansion_count = ceil(64 / (2 × 16)) = 2 + +The MLIR IR only sees the already-expanded result: two ``wave.mma`` ops along M +rather than one. The ``vector_shapes`` remain on the +``#wave.hardware_constraint`` for verification and for passes that need to +reason about the per-wave tile. + +**When mma_type is absent,** the MLIR verifier enforces that each +``vector_shapes`` entry **matches** the resolved tile size from the +corresponding ``#wave.workgroup_constraint`` or ``#wave.tiling_constraint`` for +that dimension. Unlike with mma_operations that have a fixed size, element wise operations +can operate on any number of elements_per_thread and thus don't need to be expanded multiple times. +A mismatch is a verification error: + +.. code-block:: mlir + + // ERROR: vector_shapes entry 'M' (16) does not match + // workgroup constraint tile size (32) + #wave.hardware_constraint + +This means that in non-MMA programs, there is no separate expansion step: +``vector_shapes`` equals the tile size and each operation appears exactly once +per dimension. + + +MMA kind and intrinsic shapes +------------------------------ + +``WaveMmaKindEnum`` enumerates hardware matrix multiply intrinsics. Each +variant encodes the output element type, tile shape (M×N×K), and input element +type. Examples: + +.. code-block:: mlir + + #wave.mma_kind // (M=16, N=16, K=16) + #wave.mma_kind // (M=32, N=32, K=8) + #wave.mma_kind // (M=16, N=16, K=128) + +``WaveMmaKindAttr::getShape(ctx, kind)`` returns the ``(M, N, K)`` tuple. + +The ``kind`` attribute on ``wave.mma`` may differ from the ``mma_type`` on the +hardware constraint. When ``kind`` is absent, the +``PropagateDefaultsFromConstraints`` pass fills it from the hardware +constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same +function, each carries its own ``kind`` and its own effective vector shapes. + + +Relationship to elements_per_thread +------------------------------------- + +``elements_per_thread`` is an optional ``I64Attr`` on ``wave.read`` and +``wave.write``. It specifies how many contiguous elements a single thread +loads or stores in one operation instance: + +.. code-block:: mlir + + %0 = wave.read %mem { elements_per_thread = 8 } + : (!wave.tensor<[@M, @K] of f16, >) + -> !wave.tensor<[@M, @K] of f16, > + +``elements_per_thread`` is related to ``vector_shapes`` conceptually: the +vector shape for a dimension gives the total elements a wave handles, and +dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for +thread count per workgroup dimension gives the per-thread count. The +``PropagateElementsPerThread`` pass can infer ``elements_per_thread`` from the +hardware constraint when it is not explicitly provided. diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index 948a61bbea..045815d06f 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -237,18 +237,12 @@ def HardwareConstraintAttr : AttrDef { configuration rather than fundamental hardware constraints. }]; let parameters = (ins "unsigned":$threads_per_wave, - OptionalArrayRefParameter<"unsigned">:$waves_per_block, + OptionalParameter<"::mlir::DenseI32ArrayAttr">:$waves_per_block, OptionalParameter<"::wave::WaveMmaKindAttr">:$mma_type, OptionalParameter<"::mlir::DictionaryAttr">:$vector_shapes, DefaultValuedParameter<"unsigned", "128">:$max_bits_per_load); - let assemblyFormat = [{ - `<` `threads_per_wave` `=` $threads_per_wave - (`,` `waves_per_block` `=` `[` $waves_per_block^ `]`)? - (`,` `mma_type` `=` $mma_type^)? - (`,` `vector_shapes` `=` $vector_shapes^)? - (`,` `max_bits_per_load` `=` $max_bits_per_load^)? `>` - }]; + let assemblyFormat = "`<` struct(params) `>`"; let genVerifyDecl = 1; } diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 1771ba86c3..2e4b6257c3 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -558,8 +558,8 @@ class IndexExprsAnalysisInit { symbolConstraints; // Waves-per-block extracted from the hardware constraint or computed from - // wave constraints. Always stored here, even if copied from an attribute. - llvm::SmallVector wavesPerBlock; + // wave constraints. + llvm::SmallVector wavesPerBlock; }; // Lattice for propagating index expressions across wave dialect operations. diff --git a/water/include/water/Dialect/Wave/IR/WaveUtils.h b/water/include/water/Dialect/Wave/IR/WaveUtils.h index 61d40a3f6e..9867479837 100644 --- a/water/include/water/Dialect/Wave/IR/WaveUtils.h +++ b/water/include/water/Dialect/Wave/IR/WaveUtils.h @@ -62,7 +62,7 @@ llvm::LogicalResult computeWavesPerBlockFromConstraints( const llvm::SmallDenseMap &waveConstraints, wave::WaveHyperparameterAttr hyperparams, - llvm::SmallVectorImpl &wavesPerBlock); + llvm::SmallVectorImpl &wavesPerBlock); /// Permute the shape according to the mapping. void permuteShape(llvm::ArrayRef shape, diff --git a/water/include/water/c/Dialects.h b/water/include/water/c/Dialects.h index eab66043c9..ee96cf931f 100644 --- a/water/include/water/c/Dialects.h +++ b/water/include/water/c/Dialects.h @@ -475,7 +475,7 @@ mlirAttributeIsAHardwareConstraintAttr(MlirAttribute attr); /// Creates a new HardwareConstraintAttr MLIR_CAPI_EXPORTED MlirAttribute mlirHardwareConstraintAttrGet( MlirContext mlirCtx, unsigned threadsPerWave, size_t wavesPerBlockSize, - unsigned *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes, + int32_t *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes, unsigned maxBitsPerLoad); /// Returns the typeID of a HardwareConstraintAttr. @@ -486,7 +486,7 @@ MLIR_CAPI_EXPORTED unsigned mlirHardwareConstraintAttrGetThreadsPerWave(MlirAttribute attr); MLIR_CAPI_EXPORTED intptr_t mlirHardwareConstraintAttrGetNumWavesPerBlock(MlirAttribute attr); -MLIR_CAPI_EXPORTED unsigned +MLIR_CAPI_EXPORTED int32_t mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr, intptr_t i); MLIR_CAPI_EXPORTED MlirAttribute mlirHardwareConstraintAttrGetMmaType(MlirAttribute attr); diff --git a/water/lib/CAPI/Dialects.cpp b/water/lib/CAPI/Dialects.cpp index 6f1397864c..3ee5409eaa 100644 --- a/water/lib/CAPI/Dialects.cpp +++ b/water/lib/CAPI/Dialects.cpp @@ -516,7 +516,7 @@ bool mlirAttributeIsAHardwareConstraintAttr(MlirAttribute attr) { MlirAttribute mlirHardwareConstraintAttrGet(MlirContext mlirCtx, unsigned threadsPerWave, - size_t wavesPerBlockSize, unsigned *wavesPerBlock, + size_t wavesPerBlockSize, int32_t *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes, unsigned maxBitsPerLoad) { MLIRContext *ctx = unwrap(mlirCtx); @@ -525,9 +525,14 @@ mlirHardwareConstraintAttrGet(MlirContext mlirCtx, unsigned threadsPerWave, auto vectorShapesAttr = llvm::cast_if_present(unwrap(vectorShapes)); + DenseI32ArrayAttr wavesPerBlockAttr; + if (wavesPerBlockSize > 0) + wavesPerBlockAttr = DenseI32ArrayAttr::get( + ctx, llvm::ArrayRef(wavesPerBlock, wavesPerBlockSize)); + return wrap(wave::HardwareConstraintAttr::get( - ctx, threadsPerWave, llvm::ArrayRef(wavesPerBlock, wavesPerBlockSize), - mmaTypeAttr, vectorShapesAttr, maxBitsPerLoad)); + ctx, threadsPerWave, wavesPerBlockAttr, mmaTypeAttr, vectorShapesAttr, + maxBitsPerLoad)); } MlirTypeID mlirWHardwareConstraintAttrGetTypeID() { @@ -539,14 +544,15 @@ unsigned mlirHardwareConstraintAttrGetThreadsPerWave(MlirAttribute attr) { .getThreadsPerWave(); } intptr_t mlirHardwareConstraintAttrGetNumWavesPerBlock(MlirAttribute attr) { - return llvm::cast(unwrap(attr)) - .getWavesPerBlock() - .size(); + DenseI32ArrayAttr wpb = + llvm::cast(unwrap(attr)).getWavesPerBlock(); + return wpb ? wpb.size() : 0; } -unsigned mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr, - intptr_t i) { +int32_t mlirHardwareConstraintAttrGetWavesPerBlockElem(MlirAttribute attr, + intptr_t i) { return llvm::cast(unwrap(attr)) - .getWavesPerBlock()[i]; + .getWavesPerBlock() + .asArrayRef()[i]; } MlirAttribute mlirHardwareConstraintAttrGetMmaType(MlirAttribute attr) { return wrap( diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 58ec7fec5a..84b207b0dc 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -708,12 +708,11 @@ WaveExprListAttr::verify(function_ref emitError, LogicalResult HardwareConstraintAttr::verify( function_ref emitError, unsigned threadsPerWave, - ArrayRef wavesPerBlock, WaveMmaKindAttr mmaType, + DenseI32ArrayAttr wavesPerBlock, WaveMmaKindAttr mmaType, DictionaryAttr vectorShapes, unsigned maxBitsPerLoad) { - if (!(wavesPerBlock.empty() || wavesPerBlock.size() == 3)) - return emitError() << "waves_per_block (" << wavesPerBlock - << ") should have 3 elements"; + if (wavesPerBlock && wavesPerBlock.size() != 3) + return emitError() << "waves_per_block should have 3 elements"; if (vectorShapes) { for (NamedAttribute attr : vectorShapes) { diff --git a/water/lib/Dialect/Wave/IR/WaveDialect.cpp b/water/lib/Dialect/Wave/IR/WaveDialect.cpp index 2b525b222f..335b3b765e 100644 --- a/water/lib/Dialect/Wave/IR/WaveDialect.cpp +++ b/water/lib/Dialect/Wave/IR/WaveDialect.cpp @@ -222,6 +222,7 @@ verifyConstraints(ArrayAttr constraints, // * The number of workgroups should be greater than or equal to one. llvm::SmallDenseMap resolvedWorkgroupSizes( workgroupConstraints.size()); + llvm::SmallDenseMap resolvedSizes; llvm::SmallDenseSet assignedDims; llvm::SmallDenseSet needsPrimaryDim; for (auto &&[symbol, constraint] : workgroupConstraints) { @@ -250,6 +251,7 @@ verifyConstraints(ArrayAttr constraints, int64_t workgroupSize = evaluated->front(); resolvedWorkgroupSizes[symbol] = workgroupSize; + resolvedSizes[symbol] = workgroupSize; std::optional> resolvedDims = wave::resolveSymbolNames(symbol, hyperparams); @@ -310,15 +312,16 @@ verifyConstraints(ArrayAttr constraints, << workgroupSize << " for dimension: " << symbol; } resolvedWaveCounts[symbol] = numWaves; + resolvedSizes[symbol] = resolvedWaveSize; } // verify consistency between wave constraints and waves_per_block // * If both wave constraints and waves_per_block are present, the computed // number of waves per dimension should match the waves_per_block attribute. - if (hardwareConstraint && !hardwareConstraint.getWavesPerBlock().empty() && + if (hardwareConstraint && hardwareConstraint.getWavesPerBlock() && !waveConstraints.empty()) { - llvm::ArrayRef wavesPerBlock = - hardwareConstraint.getWavesPerBlock(); + llvm::ArrayRef wavesPerBlock = + hardwareConstraint.getWavesPerBlock().asArrayRef(); for (auto &&[symbol, waveConstraint] : waveConstraints) { wave::WorkgroupConstraintAttr wgConstraint = workgroupConstraints[symbol]; unsigned wgDim = @@ -335,6 +338,8 @@ verifyConstraints(ArrayAttr constraints, // verify TilingConstraint // * The number of tiles should be greater than or equal to one. + llvm::SmallDenseMap resolvedTilingSizes( + tilingConstraints.size()); for (auto &&[symbol, constraint] : tilingConstraints) { std::optional> evaluated = wave::evaluateMapWithHyperparams(constraint.getTileSize().getMap(), @@ -351,6 +356,8 @@ verifyConstraints(ArrayAttr constraints, "failed to resolve dimesion symbol"); int64_t resolvedTileSize = evaluated->front(); + resolvedTilingSizes[symbol] = resolvedTileSize; + resolvedSizes[symbol] = resolvedTileSize; int64_t resolvedDim = resolvedDims->front(); int64_t numTiles = resolvedDim / resolvedTileSize; if (numTiles < 1) { @@ -359,6 +366,38 @@ verifyConstraints(ArrayAttr constraints, } } + // Verify consistency between constraints and vector_shapes (when mma_type + // is absent). Each vector_shapes entry must match the resolved tile size + // from the most specific constraint for that dimension: WaveConstraint > + // WorkgroupConstraint > TilingConstraint. + if (hardwareConstraint && hardwareConstraint.getVectorShapes() && + !hardwareConstraint.getMmaType()) { + DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes(); + for (NamedAttribute dimension : vectorShapes) { + llvm::StringRef symbolName = dimension.getName().getValue(); + int64_t size = llvm::cast(dimension.getValue()).getInt(); + + wave::WaveSymbolAttr symbol = + wave::WaveSymbolAttr::get(hyperparams.getContext(), symbolName); + + auto it = resolvedSizes.find(symbol); + + if (it == resolvedSizes.end()) { + // Batch dimensions may not be present in the resolved sizes map. + continue; + } + + int64_t resolvedSize = it->second; + + if (size > resolvedSize) { + return emitError() << "vector_shapes entry '" << symbolName << "' (" + << size + << ") is greater than the resolved tile size (" + << resolvedSize << ") for dimension: " << symbol; + } + } + } + return llvm::success(); } diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index eb6b459a55..ba1b7af5b1 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -889,11 +889,11 @@ wave::IndexExprsAnalysisInit::create(Location loc, Attribute constraintsAttr, // If waves_per_block is explicitly provided, copy it to storage. Note that we // have verified they match the result of dividing block tiles with wave tiles // previously. - if (!initObject.hardwareConstraint.getWavesPerBlock().empty()) { + if (initObject.hardwareConstraint.getWavesPerBlock()) { assert(initObject.hardwareConstraint.getWavesPerBlock().size() == 3 && "expected waves_per_block to have 3 elements"); - llvm::ArrayRef explicitWpb = - initObject.hardwareConstraint.getWavesPerBlock(); + llvm::ArrayRef explicitWpb = + initObject.hardwareConstraint.getWavesPerBlock().asArrayRef(); initObject.wavesPerBlock.assign(explicitWpb); return initObject; } diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 980005ea66..ab8399b580 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -870,7 +870,7 @@ void MmaSingleIndexExprBuilder::populate( // dimension. static llvm::LogicalResult populateMmaIndexingExpr(wave::WaveMmaKind kind, bool isAccumulator, - llvm::ArrayRef wavesPerWorkgroup, + llvm::ArrayRef wavesPerWorkgroup, int64_t threadsPerWave, wave::WaveSymbolAttr mSymbol, wave::WaveSymbolAttr nSymbol, wave::WaveSymbolAttr kSymbol, diff --git a/water/lib/Dialect/Wave/IR/WaveUtils.cpp b/water/lib/Dialect/Wave/IR/WaveUtils.cpp index 68c1ded6d9..3467eab293 100644 --- a/water/lib/Dialect/Wave/IR/WaveUtils.cpp +++ b/water/lib/Dialect/Wave/IR/WaveUtils.cpp @@ -133,7 +133,7 @@ LogicalResult wave::computeWavesPerBlockFromConstraints( const llvm::SmallDenseMap &waveConstraints, wave::WaveHyperparameterAttr hyperparams, - SmallVectorImpl &wavesPerBlock) { + SmallVectorImpl &wavesPerBlock) { // Default to 1 wave per block for each dimension, this may be recomputed // later if the corresponding constraints are provided. wavesPerBlock.assign(/*NumElts=*/3, /*Elt=*/1); @@ -168,7 +168,7 @@ LogicalResult wave::computeWavesPerBlockFromConstraints( int64_t numWaves = workgroupSize / waveSize; unsigned wgDim = static_cast(wgConstraint.getWorkgroupDim().getValue()); - wavesPerBlock[wgDim] = static_cast(numWaves); + wavesPerBlock[wgDim] = static_cast(numWaves); } return success(); diff --git a/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp b/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp index 7259622c33..3f5b753404 100644 --- a/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp +++ b/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp @@ -1327,9 +1327,12 @@ static void warnIfReductionScopeMismatch(Operation *op, bool isBlockReduction) { llvm::dyn_cast(constraint); if (!hwConstraint) continue; - ArrayRef wavesPerBlock = hwConstraint.getWavesPerBlock(); - unsigned totalWaves = 1; - for (unsigned w : wavesPerBlock) + if (!hwConstraint.getWavesPerBlock()) + continue; + llvm::ArrayRef wavesPerBlock = + hwConstraint.getWavesPerBlock().asArrayRef(); + int64_t totalWaves = 1; + for (int32_t w : wavesPerBlock) totalWaves *= w; if (isBlockReduction && totalWaves == 1) { op->emitWarning() diff --git a/water/python/WaterExtensionNanobind.cpp b/water/python/WaterExtensionNanobind.cpp index c19b9cc00e..84fe84221f 100644 --- a/water/python/WaterExtensionNanobind.cpp +++ b/water/python/WaterExtensionNanobind.cpp @@ -784,15 +784,15 @@ struct PyHardwareConstraintAttr c.def_static( "get", [](unsigned threadsPerWave, - const std::optional> &wavesPerBlock, + const std::optional> &wavesPerBlock, std::optional mmaType, std::optional vectorShapes, unsigned maxBitsPerLoad, mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext context) { - unsigned *wavesPerBlockPtr = nullptr; + int32_t *wavesPerBlockPtr = nullptr; size_t wavesPerBlockSize = 0; if (wavesPerBlock.has_value()) { - wavesPerBlockPtr = const_cast(wavesPerBlock->data()); + wavesPerBlockPtr = const_cast(wavesPerBlock->data()); wavesPerBlockSize = wavesPerBlock->size(); } @@ -812,7 +812,7 @@ struct PyHardwareConstraintAttr return mlirHardwareConstraintAttrGetThreadsPerWave(self); }); c.def_prop_ro("waves_per_block", [](MlirAttribute self) { - std::vector out; + std::vector out; intptr_t n = mlirHardwareConstraintAttrGetNumWavesPerBlock(self); out.reserve(n); for (intptr_t i = 0; i < n; ++i) diff --git a/water/test/Dialect/Wave/attr-constraint-invalid.mlir b/water/test/Dialect/Wave/attr-constraint-invalid.mlir index 127b0d9f4e..0f9c610c11 100644 --- a/water/test/Dialect/Wave/attr-constraint-invalid.mlir +++ b/water/test/Dialect/Wave/attr-constraint-invalid.mlir @@ -1,6 +1,6 @@ // RUN: water-opt %s --allow-unregistered-dialect --water-test-wave-dialect-functions --split-input-file --verify-diagnostics -// expected-error @below {{waves_per_block (1) should have 3 elements}} +// expected-error @below {{waves_per_block should have 3 elements}} #hw_constraint = #wave.hardware_constraint, @@ -252,3 +252,28 @@ func.func private @test_waves_per_block_mismatch_multi_dim() attributes { wave.h func.func private @test_waves_per_block_mismatch_y_dim() attributes { wave.hyperparameters = #hyperparams_wpb3, wave.constraints = [#wg_constraint_wpb3, #wv_constraint_wpb3, #hw_constraint_wpb3] } // ----- + +#hyperparams_vs1 = #wave.hyperparameters<{M = 64, BLOCK_M = 32}> +#wg_constraint_vs1 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#hw_constraint_vs1 = #wave.hardware_constraint +// expected-error @below {{vector_shapes entry 'M' (64) is greater than the resolved tile size (32) for dimension: #wave.symbol<"M">}} +func.func private @test_vector_shapes_wg_greater() attributes { wave.hyperparameters = #hyperparams_vs1, wave.constraints = [#wg_constraint_vs1, #hw_constraint_vs1] } + +// ----- + +#hyperparams_vs2 = #wave.hyperparameters<{K = 1024, BLOCK_K = 128}> +#tl_constraint_vs2 = #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> +#hw_constraint_vs2 = #wave.hardware_constraint +// expected-error @below {{vector_shapes entry 'K' (256) is greater than the resolved tile size (128) for dimension: #wave.symbol<"K">}} +func.func private @test_vector_shapes_tiling_greater() attributes { wave.hyperparameters = #hyperparams_vs2, wave.constraints = [#tl_constraint_vs2, #hw_constraint_vs2] } + +// ----- + +#hyperparams_vs3 = #wave.hyperparameters<{M = 1024, BLOCK_M = 128}> +#wg_constraint_vs3 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#wv_constraint_vs3 = #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> +#hw_constraint_vs3 = #wave.hardware_constraint +// expected-error @below {{vector_shapes entry 'M' (64) is greater than the resolved tile size (32) for dimension: #wave.symbol<"M">}} +func.func private @test_vector_shapes_wave_greater() attributes { wave.hyperparameters = #hyperparams_vs3, wave.constraints = [#wg_constraint_vs3, #wv_constraint_vs3, #hw_constraint_vs3] } + +// ----- diff --git a/water/test/Dialect/Wave/attr-constraint.mlir b/water/test/Dialect/Wave/attr-constraint.mlir index 8b7351956e..87b83fe20b 100644 --- a/water/test/Dialect/Wave/attr-constraint.mlir +++ b/water/test/Dialect/Wave/attr-constraint.mlir @@ -108,3 +108,65 @@ func.func private @test_waves_per_block_no_wave_constraints() attributes { wave. #wv_constraint_wpb_valid4 = #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> #hw_constraint_wpb_valid4 = #wave.hardware_constraint func.func private @test_wave_constraints_no_waves_per_block() attributes { wave.hyperparameters = #hyperparams_wpb_valid4, wave.constraints = [#wg_constraint_wpb_valid4, #wv_constraint_wpb_valid4, #hw_constraint_wpb_valid4] } + +// CHECK-LABEL: @test_vector_shapes_match_wg +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid1 = #wave.hyperparameters<{M = 1024, BLOCK_M = 128}> +#wg_constraint_vs_valid1 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#hw_constraint_vs_valid1 = #wave.hardware_constraint +func.func private @test_vector_shapes_match_wg() attributes { wave.hyperparameters = #hyperparams_vs_valid1, wave.constraints = [#wg_constraint_vs_valid1, #hw_constraint_vs_valid1] } + +// CHECK-LABEL: @test_vector_shapes_match_wave +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid2 = #wave.hyperparameters<{M = 1024, BLOCK_M = 128}> +#wg_constraint_vs_valid2 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#wv_constraint_vs_valid2 = #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> +#hw_constraint_vs_valid2 = #wave.hardware_constraint +func.func private @test_vector_shapes_match_wave() attributes { wave.hyperparameters = #hyperparams_vs_valid2, wave.constraints = [#wg_constraint_vs_valid2, #wv_constraint_vs_valid2, #hw_constraint_vs_valid2] } + +// CHECK-LABEL: @test_vector_shapes_match_tiling +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid3 = #wave.hyperparameters<{K = 1024, BLOCK_K = 128}> +#tl_constraint_vs_valid3 = #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> +#hw_constraint_vs_valid3 = #wave.hardware_constraint +func.func private @test_vector_shapes_match_tiling() attributes { wave.hyperparameters = #hyperparams_vs_valid3, wave.constraints = [#tl_constraint_vs_valid3, #hw_constraint_vs_valid3] } + +// CHECK-LABEL: @test_vector_shapes_with_mma_type +// CHECK: #wave.hardware_constraint, vector_shapes = {K = 99 : i64, M = 99 : i64}> +#hyperparams_vs_valid4 = #wave.hyperparameters<{M = 1024, K = 1024, BLOCK_M = 128, BLOCK_K = 128}> +#wg_constraint_vs_valid4 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#tl_constraint_vs_valid4 = #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> +#hw_constraint_vs_valid4 = #wave.hardware_constraint, vector_shapes = {M = 99, K = 99}> +func.func private @test_vector_shapes_with_mma_type() attributes { wave.hyperparameters = #hyperparams_vs_valid4, wave.constraints = [#wg_constraint_vs_valid4, #tl_constraint_vs_valid4, #hw_constraint_vs_valid4] } + +// CHECK-LABEL: @test_vector_shapes_multi_dim +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid5 = #wave.hyperparameters<{M = 1024, N = 512, K = 1024, BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 64}> +#wg_constraint_vs_valid5_m = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#wg_constraint_vs_valid5_n = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > +#wv_constraint_vs_valid5_m = #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> +#tl_constraint_vs_valid5_k = #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> +#hw_constraint_vs_valid5 = #wave.hardware_constraint +func.func private @test_vector_shapes_multi_dim() attributes { wave.hyperparameters = #hyperparams_vs_valid5, wave.constraints = [#wg_constraint_vs_valid5_m, #wg_constraint_vs_valid5_n, #wv_constraint_vs_valid5_m, #tl_constraint_vs_valid5_k, #hw_constraint_vs_valid5] } + +// CHECK-LABEL: @test_vector_shapes_smaller_than_wg +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid6 = #wave.hyperparameters<{M = 64, BLOCK_M = 32}> +#wg_constraint_vs_valid6 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#hw_constraint_vs_valid6 = #wave.hardware_constraint +func.func private @test_vector_shapes_smaller_than_wg() attributes { wave.hyperparameters = #hyperparams_vs_valid6, wave.constraints = [#wg_constraint_vs_valid6, #hw_constraint_vs_valid6] } + +// CHECK-LABEL: @test_vector_shapes_smaller_than_tiling +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid7 = #wave.hyperparameters<{K = 1024, BLOCK_K = 128}> +#tl_constraint_vs_valid7 = #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> +#hw_constraint_vs_valid7 = #wave.hardware_constraint +func.func private @test_vector_shapes_smaller_than_tiling() attributes { wave.hyperparameters = #hyperparams_vs_valid7, wave.constraints = [#tl_constraint_vs_valid7, #hw_constraint_vs_valid7] } + +// CHECK-LABEL: @test_vector_shapes_smaller_than_wave +// CHECK: #wave.hardware_constraint +#hyperparams_vs_valid8 = #wave.hyperparameters<{M = 1024, BLOCK_M = 128}> +#wg_constraint_vs_valid8 = #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = > +#wv_constraint_vs_valid8 = #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> +#hw_constraint_vs_valid8 = #wave.hardware_constraint +func.func private @test_vector_shapes_smaller_than_wave() attributes { wave.hyperparameters = #hyperparams_vs_valid8, wave.constraints = [#wg_constraint_vs_valid8, #wv_constraint_vs_valid8, #hw_constraint_vs_valid8] }