Skip to content

Commit 283d76b

Browse files
committed
verify vector_shapes consistency with WorkgroupConstraint and TilingConstraint
Signed-off-by: Tim Gymnich <tim@gymni.ch>
1 parent 4b2f780 commit 283d76b

2 files changed

Lines changed: 58 additions & 1 deletion

File tree

water/lib/Dialect/Wave/IR/WaveDialect.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ verifyConstraints(ArrayAttr constraints,
335335

336336
// verify TilingConstraint
337337
// * The number of tiles should be greater than or equal to one.
338+
llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t> resolvedTilingSizes(
339+
tilingConstraints.size());
338340
for (auto &&[symbol, constraint] : tilingConstraints) {
339341
std::optional<llvm::SmallVector<int64_t>> evaluated =
340342
wave::evaluateMapWithHyperparams(constraint.getTileSize().getMap(),
@@ -351,6 +353,7 @@ verifyConstraints(ArrayAttr constraints,
351353
"failed to resolve dimesion symbol");
352354

353355
int64_t resolvedTileSize = evaluated->front();
356+
resolvedTilingSizes[symbol] = resolvedTileSize;
354357
int64_t resolvedDim = resolvedDims->front();
355358
int64_t numTiles = resolvedDim / resolvedTileSize;
356359
if (numTiles < 1) {
@@ -359,6 +362,44 @@ verifyConstraints(ArrayAttr constraints,
359362
}
360363
}
361364

365+
// Verify consistency between vector_shapes and constraint tile sizes.
366+
// * When no mma_type is set, vector_shapes should match the resolved tile
367+
// sizes from the corresponding WorkgroupConstraint/TilingConstraint.
368+
// * When mma_type is present, vector_shapes derives from the MMA instruction
369+
// geometry and is expected to differ from the constraint tile sizes.
370+
if (hardwareConstraint && hardwareConstraint.getVectorShapes() &&
371+
!hardwareConstraint.getMmaType()) {
372+
DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes();
373+
for (NamedAttribute entry : vectorShapes) {
374+
llvm::StringRef dimName = entry.getName();
375+
int64_t vectorShapeSize =
376+
llvm::cast<IntegerAttr>(entry.getValue()).getInt();
377+
378+
for (auto &&[symbol, size] : resolvedWorkgroupSizes) {
379+
if (symbol.getName() != dimName)
380+
continue;
381+
if (size != vectorShapeSize) {
382+
return emitError()
383+
<< "vector_shapes entry '" << dimName << "' ("
384+
<< vectorShapeSize
385+
<< ") does not match workgroup constraint tile size (" << size
386+
<< ") for dimension: " << symbol;
387+
}
388+
}
389+
390+
for (auto &&[symbol, size] : resolvedTilingSizes) {
391+
if (symbol.getName() != dimName)
392+
continue;
393+
if (size != vectorShapeSize) {
394+
return emitError() << "vector_shapes entry '" << dimName << "' ("
395+
<< vectorShapeSize
396+
<< ") does not match tiling constraint tile size ("
397+
<< size << ") for dimension: " << symbol;
398+
}
399+
}
400+
}
401+
}
402+
362403
return llvm::success();
363404
}
364405

water/test/Dialect/Wave/attr-constraint-invalid.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: water-opt %s --allow-unregistered-dialect --water-test-wave-dialect-functions --split-input-file --verify-diagnostics
22

3-
// expected-error @below {{waves_per_block (1) should have 3 elements}}
3+
// expected-error @below {{waves_per_block should have 3 elements}}
44
#hw_constraint = #wave.hardware_constraint<threads_per_wave = 64,
55
waves_per_block = [1],
66
mma_type = #wave.mma_kind<f32_16x16x16_f16>,
@@ -252,3 +252,19 @@ func.func private @test_waves_per_block_mismatch_multi_dim() attributes { wave.h
252252
func.func private @test_waves_per_block_mismatch_y_dim() attributes { wave.hyperparameters = #hyperparams_wpb3, wave.constraints = [#wg_constraint_wpb3, #wv_constraint_wpb3, #hw_constraint_wpb3] }
253253

254254
// -----
255+
256+
#hyperparams_vs1 = #wave.hyperparameters<{M = 64, BLOCK_M = 32}>
257+
#wg_constraint_vs1 = #wave.workgroup_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = <x>>
258+
#hw_constraint_vs1 = #wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 16}>
259+
// expected-error @below {{vector_shapes entry 'M' (16) does not match workgroup constraint tile size (32) for dimension: #wave.symbol<"M">}}
260+
func.func private @test_vector_shapes_wg_mismatch() attributes { wave.hyperparameters = #hyperparams_vs1, wave.constraints = [#wg_constraint_vs1, #hw_constraint_vs1] }
261+
262+
// -----
263+
264+
#hyperparams_vs2 = #wave.hyperparameters<{K = 1024, BLOCK_K = 128}>
265+
#tl_constraint_vs2 = #wave.tiling_constraint<dim = <"K">, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>>
266+
#hw_constraint_vs2 = #wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {K = 32}>
267+
// expected-error @below {{vector_shapes entry 'K' (32) does not match tiling constraint tile size (128) for dimension: #wave.symbol<"K">}}
268+
func.func private @test_vector_shapes_tiling_mismatch() attributes { wave.hyperparameters = #hyperparams_vs2, wave.constraints = [#tl_constraint_vs2, #hw_constraint_vs2] }
269+
270+
// -----

0 commit comments

Comments
 (0)