1010#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1111#include " iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1212#include " iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
13+ #include " iree/compiler/Codegen/Utils/GPUUtils.h"
1314#include " iree/compiler/Dialect/LinalgExt/IR/Im2colUtils.h"
1415#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
1516#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -50,6 +51,63 @@ static bool getBoolOption(DictionaryAttr options, StringRef name,
5051 return defaultValue;
5152}
5253
54+ static std::optional<vector::CombiningKind> matchScanCombiner (Region ®ion) {
55+ if (!region.hasOneBlock ()) {
56+ return std::nullopt ;
57+ }
58+
59+ Block &block = region.front ();
60+ if (block.getNumArguments () != 2 ) {
61+ return std::nullopt ;
62+ }
63+
64+ auto &ops = block.getOperations ();
65+ if (ops.size () != 2 ) {
66+ return std::nullopt ;
67+ }
68+
69+ Operation &firstOp = ops.front ();
70+ Operation &yieldOp = ops.back ();
71+ if (firstOp.getNumOperands () != 2 || firstOp.getNumResults () != 1 ) {
72+ return std::nullopt ;
73+ }
74+ if (yieldOp.getNumOperands () != 1 ||
75+ yieldOp.getOperand (0 ) != firstOp.getResult (0 )) {
76+ return std::nullopt ;
77+ }
78+
79+ Value arg0 = block.getArgument (0 );
80+ Value arg1 = block.getArgument (1 );
81+ Value opArg0 = firstOp.getOperand (0 );
82+ Value opArg1 = firstOp.getOperand (1 );
83+ if (opArg0 != arg0 || opArg1 != arg1) {
84+ return std::nullopt ;
85+ }
86+
87+ return llvm::TypeSwitch<Operation *, std::optional<vector::CombiningKind>>(
88+ &firstOp)
89+ .Case <arith::AddIOp, arith::AddFOp>(
90+ [](auto ) { return vector::CombiningKind::ADD; })
91+ .Case <arith::MulIOp, arith::MulFOp>(
92+ [](auto ) { return vector::CombiningKind::MUL; })
93+ .Case <arith::AndIOp>([](auto ) { return vector::CombiningKind::AND; })
94+ .Case <arith::OrIOp>([](auto ) { return vector::CombiningKind::OR; })
95+ .Case <arith::XOrIOp>([](auto ) { return vector::CombiningKind::XOR; })
96+ .Case <arith::MaxSIOp>([](auto ) { return vector::CombiningKind::MAXSI; })
97+ .Case <arith::MaxUIOp>([](auto ) { return vector::CombiningKind::MAXUI; })
98+ .Case <arith::MinSIOp>([](auto ) { return vector::CombiningKind::MINSI; })
99+ .Case <arith::MinUIOp>([](auto ) { return vector::CombiningKind::MINUI; })
100+ .Case <arith::MaximumFOp>(
101+ [](auto ) { return vector::CombiningKind::MAXIMUMF; })
102+ .Case <arith::MinimumFOp>(
103+ [](auto ) { return vector::CombiningKind::MINIMUMF; })
104+ .Case <arith::MaxNumFOp>(
105+ [](auto ) { return vector::CombiningKind::MAXNUMF; })
106+ .Case <arith::MinNumFOp>(
107+ [](auto ) { return vector::CombiningKind::MINNUMF; })
108+ .Default ([](Operation *) { return std::nullopt ; });
109+ }
110+
53111struct GatherOpVectorizationModel
54112 : VectorizableOpInterface::ExternalModel<GatherOpVectorizationModel,
55113 IREE::LinalgExt::GatherOp> {
@@ -1342,6 +1400,138 @@ struct Im2colOpVectorizationModel
13421400 return SmallVector<Value>{result};
13431401 }
13441402};
1403+
1404+ struct ScanOpVectorizationModel
1405+ : VectorizableOpInterface::ExternalModel<ScanOpVectorizationModel,
1406+ IREE::LinalgExt::ScanOp> {
1407+
1408+ bool isVectorizable (Operation *op, ArrayRef<int64_t > vectorSizes,
1409+ ArrayRef<bool > scalableDims,
1410+ DictionaryAttr options) const {
1411+ auto scanOp = cast<IREE::LinalgExt::ScanOp>(op);
1412+
1413+ // Must be able to match region to CombiningKind.
1414+ if (!matchScanCombiner (scanOp.getRegion ())) {
1415+ return false ;
1416+ }
1417+
1418+ // Scalable vectors not yet supported.
1419+ if (llvm::any_of (scalableDims, [](bool b) { return b; })) {
1420+ return false ;
1421+ }
1422+
1423+ // Without vector sizes, require static shapes.
1424+ if (vectorSizes.empty ()) {
1425+ auto inputTy = cast<ShapedType>(scanOp.getInput ().getType ());
1426+ return inputTy.hasStaticShape ();
1427+ }
1428+
1429+ return true ;
1430+ }
1431+
1432+ FailureOr<SmallVector<Value>> vectorize (Operation *op, RewriterBase &rewriter,
1433+ ArrayRef<int64_t > vectorSizes,
1434+ ArrayRef<bool > scalableDims,
1435+ DictionaryAttr options) const {
1436+ auto scanOp = cast<IREE::LinalgExt::ScanOp>(op);
1437+ Location loc = scanOp.getLoc ();
1438+ RewriterBase::InsertionGuard g (rewriter);
1439+ rewriter.setInsertionPoint (scanOp);
1440+
1441+ // Match combiner to CombiningKind.
1442+ auto kind = matchScanCombiner (scanOp.getRegion ());
1443+ if (!kind) {
1444+ return failure ();
1445+ }
1446+
1447+ // Determine vector shapes.
1448+ auto inputTy = cast<ShapedType>(scanOp.getInput ().getType ());
1449+ auto accumTy = cast<ShapedType>(scanOp.getAccumulator ().getType ());
1450+ Type elemType = inputTy.getElementType ();
1451+ int64_t inputRank = inputTy.getRank ();
1452+ int64_t scanDim = scanOp.getDimension ();
1453+
1454+ SmallVector<int64_t > inputVecShape =
1455+ vectorSizes.empty () ? llvm::to_vector (inputTy.getShape ())
1456+ : llvm::to_vector (vectorSizes);
1457+
1458+ // Accumulator shape = input shape with scan dimension dropped.
1459+ SmallVector<int64_t > accumVecShape = inputVecShape;
1460+ accumVecShape.erase (accumVecShape.begin () + scanDim);
1461+
1462+ auto inputVecTy = VectorType::get (inputVecShape, elemType);
1463+ auto accumVecTy = VectorType::get (accumVecShape, elemType);
1464+
1465+ // Determine if masking is needed (dynamic shapes or vector > tensor).
1466+ bool needsInputMasking = !inputTy.hasStaticShape () ||
1467+ !llvm::equal (inputTy.getShape (), inputVecShape);
1468+ bool needsAccumMasking = !accumTy.hasStaticShape () ||
1469+ !llvm::equal (accumTy.getShape (), accumVecShape);
1470+
1471+ Value zero = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1472+ SmallVector<Value> inputIndices (inputRank, zero);
1473+ SmallVector<Value> accumIndices (accumTy.getRank (), zero);
1474+
1475+ // Read input tensor to vector.
1476+ Value padding = ub::PoisonOp::create (rewriter, loc, elemType);
1477+ Value inputVec = vector::createReadOrMaskedRead (
1478+ rewriter, loc, scanOp.getInput (), inputVecShape, padding,
1479+ /* useInBoundsInsteadOfMasking=*/ !needsInputMasking);
1480+ if (needsInputMasking) {
1481+ SmallVector<OpFoldResult> inputDims =
1482+ tensor::getMixedSizes (rewriter, loc, scanOp.getInput ());
1483+ auto inputMaskTy = VectorType::get (inputVecShape, rewriter.getI1Type ());
1484+ Value inputMask = vector::CreateMaskOp::create (
1485+ rewriter, loc, inputMaskTy,
1486+ getValueOrCreateConstantIndexOp (rewriter, loc, inputDims));
1487+
1488+ // Replace masked-off lanes with identity value.
1489+ Value identity =
1490+ getCombiningIdentityValue (loc, rewriter, *kind, inputVecTy);
1491+ inputVec =
1492+ arith::SelectOp::create (rewriter, loc, inputMask, inputVec, identity);
1493+ }
1494+
1495+ // Read accumulator (initial value) to vector.
1496+ Value accumVec = vector::createReadOrMaskedRead (
1497+ rewriter, loc, scanOp.getAccumulator (), accumVecShape, padding,
1498+ /* useInBoundsInsteadOfMasking=*/ !needsAccumMasking);
1499+ if (needsAccumMasking) {
1500+ SmallVector<OpFoldResult> accumDims =
1501+ tensor::getMixedSizes (rewriter, loc, scanOp.getAccumulator ());
1502+ auto accumMaskTy = VectorType::get (accumVecShape, rewriter.getI1Type ());
1503+ Value accumMask = vector::CreateMaskOp::create (
1504+ rewriter, loc, accumMaskTy,
1505+ getValueOrCreateConstantIndexOp (rewriter, loc, accumDims));
1506+
1507+ Value identity =
1508+ getCombiningIdentityValue (loc, rewriter, *kind, accumVecTy);
1509+ accumVec =
1510+ arith::SelectOp::create (rewriter, loc, accumMask, accumVec, identity);
1511+ }
1512+
1513+ // Create vector.scan.
1514+ auto vectorScanOp =
1515+ vector::ScanOp::create (rewriter, loc, *kind, inputVec, accumVec,
1516+ scanDim, scanOp.getInclusive ());
1517+
1518+ // Write results back to tensors.
1519+ Value output = vector::createWriteOrMaskedWrite (
1520+ rewriter, loc, vectorScanOp.getDest (),
1521+ scanOp.getOutput (), inputIndices,
1522+ /* useInBoundsInsteadOfMasking=*/ !needsInputMasking)
1523+ ->getResult (0 );
1524+
1525+ Value accum = vector::createWriteOrMaskedWrite (
1526+ rewriter, loc, vectorScanOp.getAccumulatedValue (),
1527+ scanOp.getAccumulator (), accumIndices,
1528+ /* useInBoundsInsteadOfMasking=*/ !needsAccumMasking)
1529+ ->getResult (0 );
1530+
1531+ return SmallVector<Value>{output, accum};
1532+ }
1533+ };
1534+
13451535} // namespace
13461536
13471537void registerVectorizableOpInterfaceExternalModels (DialectRegistry ®istry) {
@@ -1355,6 +1545,7 @@ void registerVectorizableOpInterfaceExternalModels(DialectRegistry ®istry) {
13551545 *ctx);
13561546 IREE::LinalgExt::Im2colOp::attachInterface<Im2colOpVectorizationModel>(
13571547 *ctx);
1548+ IREE::LinalgExt::ScanOp::attachInterface<ScanOpVectorizationModel>(*ctx);
13581549 });
13591550 registry.addExtension (+[](MLIRContext *ctx,
13601551 IREE::VectorExt::IREEVectorExtDialect *dialect) {
0 commit comments