@@ -335,6 +335,8 @@ verifyConstraints(ArrayAttr constraints,
335335
336336 // verify TilingConstraint
337337 // * The number of tiles should be greater than or equal to one.
338+ llvm::SmallDenseMap<wave::WaveSymbolAttr, int64_t > resolvedTilingSizes (
339+ tilingConstraints.size ());
338340 for (auto &&[symbol, constraint] : tilingConstraints) {
339341 std::optional<llvm::SmallVector<int64_t >> evaluated =
340342 wave::evaluateMapWithHyperparams (constraint.getTileSize ().getMap (),
@@ -351,6 +353,7 @@ verifyConstraints(ArrayAttr constraints,
351353 " failed to resolve dimesion symbol" );
352354
353355 int64_t resolvedTileSize = evaluated->front ();
356+ resolvedTilingSizes[symbol] = resolvedTileSize;
354357 int64_t resolvedDim = resolvedDims->front ();
355358 int64_t numTiles = resolvedDim / resolvedTileSize;
356359 if (numTiles < 1 ) {
@@ -359,6 +362,44 @@ verifyConstraints(ArrayAttr constraints,
359362 }
360363 }
361364
365+ // Verify consistency between vector_shapes and constraint tile sizes.
366+ // * When no mma_type is set, vector_shapes should match the resolved tile
367+ // sizes from the corresponding WorkgroupConstraint/TilingConstraint.
368+ // * When mma_type is present, vector_shapes derives from the MMA instruction
369+ // geometry and is expected to differ from the constraint tile sizes.
370+ if (hardwareConstraint && hardwareConstraint.getVectorShapes () &&
371+ !hardwareConstraint.getMmaType ()) {
372+ DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes ();
373+ for (NamedAttribute entry : vectorShapes) {
374+ llvm::StringRef dimName = entry.getName ();
375+ int64_t vectorShapeSize =
376+ llvm::cast<IntegerAttr>(entry.getValue ()).getInt ();
377+
378+ for (auto &&[symbol, size] : resolvedWorkgroupSizes) {
379+ if (symbol.getName () != dimName)
380+ continue ;
381+ if (size != vectorShapeSize) {
382+ return emitError ()
383+ << " vector_shapes entry '" << dimName << " ' ("
384+ << vectorShapeSize
385+ << " ) does not match workgroup constraint tile size (" << size
386+ << " ) for dimension: " << symbol;
387+ }
388+ }
389+
390+ for (auto &&[symbol, size] : resolvedTilingSizes) {
391+ if (symbol.getName () != dimName)
392+ continue ;
393+ if (size != vectorShapeSize) {
394+ return emitError () << " vector_shapes entry '" << dimName << " ' ("
395+ << vectorShapeSize
396+ << " ) does not match tiling constraint tile size ("
397+ << size << " ) for dimension: " << symbol;
398+ }
399+ }
400+ }
401+ }
402+
362403 return llvm::success ();
363404}
364405
0 commit comments