Skip to content

Commit d7d4d0d

Browse files
authored
[Codegen] Vectorize linalg_ext.scan to vector.scan (#24187)
During `GenericVectorization`, vectorize `linalg_ext.scan` to `vector.scan`. If masking is required, it is introduced directly as masked `transfer_read/write` and an `arith.select` selecting between the actual data for unmasked elements and the identity for the combining operation in the scan (e.g., zero for add) for masked elements. `linalg_ext.scan` expresses the combiner as a region, where as `vector.scan` uses a fixed set of combiners as enum attribute. Therefore, we try to match the content of the region of the `linalg_ext.scan` operation against the set of supported combiners. This is part of #24186. Assisted-by: Claude Code and Codex --------- Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
1 parent d84dc73 commit d7d4d0d

5 files changed

Lines changed: 305 additions & 0 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization_masked_configured.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,32 @@ func.func @configured_zero_vector_size_falls_back_to_inference(
173173
}
174174
// CHECK-LABEL: func.func @configured_zero_vector_size_falls_back_to_inference(
175175
// CHECK: arith.addf {{.*}} : vector<4x1xf32>
176+
177+
// -----
178+
179+
#scan_masked_config = #iree_cpu.lowering_config<vector_common_parallel = [8, 16]>
180+
181+
func.func @vectorize_scan_masked_configured(
182+
%input: tensor<?x?xf32>,
183+
%output: tensor<?x?xf32>,
184+
%accum: tensor<?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
185+
%0:2 = iree_linalg_ext.scan {lowering_config = #scan_masked_config}
186+
dimension(1) inclusive(true)
187+
ins(%input : tensor<?x?xf32>)
188+
outs(%output, %accum : tensor<?x?xf32>, tensor<?xf32>) {
189+
^bb0(%arg0: f32, %arg1: f32):
190+
%sum = arith.addf %arg0, %arg1 : f32
191+
iree_linalg_ext.yield %sum : f32
192+
} -> tensor<?x?xf32>, tensor<?xf32>
193+
return %0#0, %0#1 : tensor<?x?xf32>, tensor<?xf32>
194+
}
195+
// CHECK-LABEL: func.func @vectorize_scan_masked_configured(
196+
// CHECK: vector.create_mask {{.*}} : vector<8x16xi1>
197+
// CHECK: vector.transfer_read {{.*}} : tensor<?x?xf32>, vector<8x16xf32>
198+
// CHECK: arith.select {{.*}} : vector<8x16xi1>, vector<8x16xf32>
199+
// CHECK: vector.create_mask {{.*}} : vector<8xi1>
200+
// CHECK: vector.transfer_read {{.*}} : tensor<?xf32>, vector<8xf32>
201+
// CHECK: arith.select {{.*}} : vector<8xi1>, vector<8xf32>
202+
// CHECK: vector.scan <add>, {{.*}} {inclusive = true, reduction_dim = 1 : i64}
203+
// CHECK: vector.transfer_write {{.*}} : vector<8x16xf32>, tensor<?x?xf32>
204+
// CHECK: vector.transfer_write {{.*}} : vector<8xf32>, tensor<?xf32>

compiler/src/iree/compiler/Codegen/Common/test/generic_vectorization_unmasked.mlir

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,86 @@ func.func @single_static_pack_infer_vector_size(%arg0: tensor<101x201xi8>, %arg1
5656
// TODO: Support non-masking path.
5757
// CHECK-LABEL: func.func @single_static_pack_infer_vector_size
5858
// CHECK: linalg.pack
59+
60+
// -----
61+
62+
// CHECK-LABEL: func.func @vectorize_scan_add_inclusive
63+
func.func @vectorize_scan_add_inclusive(
64+
%input: tensor<8xf32>,
65+
%output: tensor<8xf32>,
66+
%accum: tensor<f32>) -> (tensor<8xf32>, tensor<f32>) {
67+
%0:2 = iree_linalg_ext.scan
68+
dimension(0) inclusive(true)
69+
ins(%input : tensor<8xf32>)
70+
outs(%output, %accum : tensor<8xf32>, tensor<f32>) {
71+
^bb0(%arg0: f32, %arg1: f32):
72+
%sum = arith.addf %arg0, %arg1 : f32
73+
iree_linalg_ext.yield %sum : f32
74+
} -> tensor<8xf32>, tensor<f32>
75+
return %0#0, %0#1 : tensor<8xf32>, tensor<f32>
76+
}
77+
// CHECK: %[[READ:.+]] = vector.transfer_read
78+
// CHECK: %[[INIT:.+]] = vector.transfer_read
79+
// CHECK: %[[DEST:.+]], %{{.+}} = vector.scan <add>, %[[READ]], %[[INIT]]
80+
// CHECK-SAME: inclusive = true
81+
// CHECK: vector.transfer_write %[[DEST]]
82+
// CHECK: vector.transfer_write
83+
84+
// -----
85+
86+
// CHECK-LABEL: func.func @vectorize_scan_mul_exclusive
87+
func.func @vectorize_scan_mul_exclusive(
88+
%input: tensor<16xi32>,
89+
%output: tensor<16xi32>,
90+
%accum: tensor<i32>) -> (tensor<16xi32>, tensor<i32>) {
91+
%0:2 = iree_linalg_ext.scan
92+
dimension(0) inclusive(false)
93+
ins(%input : tensor<16xi32>)
94+
outs(%output, %accum : tensor<16xi32>, tensor<i32>) {
95+
^bb0(%arg0: i32, %arg1: i32):
96+
%prod = arith.muli %arg0, %arg1 : i32
97+
iree_linalg_ext.yield %prod : i32
98+
} -> tensor<16xi32>, tensor<i32>
99+
return %0#0, %0#1 : tensor<16xi32>, tensor<i32>
100+
}
101+
// CHECK: vector.scan <mul>
102+
// CHECK-SAME: inclusive = false
103+
104+
// -----
105+
106+
// CHECK-LABEL: func.func @vectorize_scan_2d
107+
func.func @vectorize_scan_2d(
108+
%input: tensor<4x8xf32>,
109+
%output: tensor<4x8xf32>,
110+
%accum: tensor<4xf32>) -> (tensor<4x8xf32>, tensor<4xf32>) {
111+
%0:2 = iree_linalg_ext.scan
112+
dimension(1) inclusive(true)
113+
ins(%input : tensor<4x8xf32>)
114+
outs(%output, %accum : tensor<4x8xf32>, tensor<4xf32>) {
115+
^bb0(%arg0: f32, %arg1: f32):
116+
%sum = arith.addf %arg0, %arg1 : f32
117+
iree_linalg_ext.yield %sum : f32
118+
} -> tensor<4x8xf32>, tensor<4xf32>
119+
return %0#0, %0#1 : tensor<4x8xf32>, tensor<4xf32>
120+
}
121+
// CHECK: vector.scan <add>
122+
// CHECK-SAME: reduction_dim = 1
123+
124+
// -----
125+
126+
// CHECK-LABEL: func.func @vectorize_scan_maxsi
127+
func.func @vectorize_scan_maxsi(
128+
%input: tensor<8xi32>,
129+
%output: tensor<8xi32>,
130+
%accum: tensor<i32>) -> (tensor<8xi32>, tensor<i32>) {
131+
%0:2 = iree_linalg_ext.scan
132+
dimension(0) inclusive(true)
133+
ins(%input : tensor<8xi32>)
134+
outs(%output, %accum : tensor<8xi32>, tensor<i32>) {
135+
^bb0(%arg0: i32, %arg1: i32):
136+
%max = arith.maxsi %arg0, %arg1 : i32
137+
iree_linalg_ext.yield %max : i32
138+
} -> tensor<8xi32>, tensor<i32>
139+
return %0#0, %0#1 : tensor<8xi32>, tensor<i32>
140+
}
141+
// CHECK: vector.scan <maxsi>

compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ iree_compiler_cc_library(
236236
":VectorizableOpInterfaceGen",
237237
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
238238
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
239+
"//compiler/src/iree/compiler/Codegen/Utils",
239240
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
240241
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
241242
"//compiler/src/iree/compiler/Utils",

compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ iree_cc_library(
188188
MLIRVectorUtils
189189
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
190190
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
191+
iree::compiler::Codegen::Utils
191192
iree::compiler::Dialect::LinalgExt::IR
192193
iree::compiler::Dialect::LinalgExt::Utils
193194
iree::compiler::Utils

compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1111
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1212
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
13+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
1314
#include "iree/compiler/Dialect/LinalgExt/IR/Im2colUtils.h"
1415
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
1516
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -50,6 +51,63 @@ static bool getBoolOption(DictionaryAttr options, StringRef name,
5051
return defaultValue;
5152
}
5253

54+
static std::optional<vector::CombiningKind> matchScanCombiner(Region &region) {
55+
if (!region.hasOneBlock()) {
56+
return std::nullopt;
57+
}
58+
59+
Block &block = region.front();
60+
if (block.getNumArguments() != 2) {
61+
return std::nullopt;
62+
}
63+
64+
auto &ops = block.getOperations();
65+
if (ops.size() != 2) {
66+
return std::nullopt;
67+
}
68+
69+
Operation &firstOp = ops.front();
70+
Operation &yieldOp = ops.back();
71+
if (firstOp.getNumOperands() != 2 || firstOp.getNumResults() != 1) {
72+
return std::nullopt;
73+
}
74+
if (yieldOp.getNumOperands() != 1 ||
75+
yieldOp.getOperand(0) != firstOp.getResult(0)) {
76+
return std::nullopt;
77+
}
78+
79+
Value arg0 = block.getArgument(0);
80+
Value arg1 = block.getArgument(1);
81+
Value opArg0 = firstOp.getOperand(0);
82+
Value opArg1 = firstOp.getOperand(1);
83+
if (opArg0 != arg0 || opArg1 != arg1) {
84+
return std::nullopt;
85+
}
86+
87+
return llvm::TypeSwitch<Operation *, std::optional<vector::CombiningKind>>(
88+
&firstOp)
89+
.Case<arith::AddIOp, arith::AddFOp>(
90+
[](auto) { return vector::CombiningKind::ADD; })
91+
.Case<arith::MulIOp, arith::MulFOp>(
92+
[](auto) { return vector::CombiningKind::MUL; })
93+
.Case<arith::AndIOp>([](auto) { return vector::CombiningKind::AND; })
94+
.Case<arith::OrIOp>([](auto) { return vector::CombiningKind::OR; })
95+
.Case<arith::XOrIOp>([](auto) { return vector::CombiningKind::XOR; })
96+
.Case<arith::MaxSIOp>([](auto) { return vector::CombiningKind::MAXSI; })
97+
.Case<arith::MaxUIOp>([](auto) { return vector::CombiningKind::MAXUI; })
98+
.Case<arith::MinSIOp>([](auto) { return vector::CombiningKind::MINSI; })
99+
.Case<arith::MinUIOp>([](auto) { return vector::CombiningKind::MINUI; })
100+
.Case<arith::MaximumFOp>(
101+
[](auto) { return vector::CombiningKind::MAXIMUMF; })
102+
.Case<arith::MinimumFOp>(
103+
[](auto) { return vector::CombiningKind::MINIMUMF; })
104+
.Case<arith::MaxNumFOp>(
105+
[](auto) { return vector::CombiningKind::MAXNUMF; })
106+
.Case<arith::MinNumFOp>(
107+
[](auto) { return vector::CombiningKind::MINNUMF; })
108+
.Default([](Operation *) { return std::nullopt; });
109+
}
110+
53111
struct GatherOpVectorizationModel
54112
: VectorizableOpInterface::ExternalModel<GatherOpVectorizationModel,
55113
IREE::LinalgExt::GatherOp> {
@@ -1342,6 +1400,138 @@ struct Im2colOpVectorizationModel
13421400
return SmallVector<Value>{result};
13431401
}
13441402
};
1403+
1404+
struct ScanOpVectorizationModel
1405+
: VectorizableOpInterface::ExternalModel<ScanOpVectorizationModel,
1406+
IREE::LinalgExt::ScanOp> {
1407+
1408+
bool isVectorizable(Operation *op, ArrayRef<int64_t> vectorSizes,
1409+
ArrayRef<bool> scalableDims,
1410+
DictionaryAttr options) const {
1411+
auto scanOp = cast<IREE::LinalgExt::ScanOp>(op);
1412+
1413+
// Must be able to match region to CombiningKind.
1414+
if (!matchScanCombiner(scanOp.getRegion())) {
1415+
return false;
1416+
}
1417+
1418+
// Scalable vectors not yet supported.
1419+
if (llvm::any_of(scalableDims, [](bool b) { return b; })) {
1420+
return false;
1421+
}
1422+
1423+
// Without vector sizes, require static shapes.
1424+
if (vectorSizes.empty()) {
1425+
auto inputTy = cast<ShapedType>(scanOp.getInput().getType());
1426+
return inputTy.hasStaticShape();
1427+
}
1428+
1429+
return true;
1430+
}
1431+
1432+
FailureOr<SmallVector<Value>> vectorize(Operation *op, RewriterBase &rewriter,
1433+
ArrayRef<int64_t> vectorSizes,
1434+
ArrayRef<bool> scalableDims,
1435+
DictionaryAttr options) const {
1436+
auto scanOp = cast<IREE::LinalgExt::ScanOp>(op);
1437+
Location loc = scanOp.getLoc();
1438+
RewriterBase::InsertionGuard g(rewriter);
1439+
rewriter.setInsertionPoint(scanOp);
1440+
1441+
// Match combiner to CombiningKind.
1442+
auto kind = matchScanCombiner(scanOp.getRegion());
1443+
if (!kind) {
1444+
return failure();
1445+
}
1446+
1447+
// Determine vector shapes.
1448+
auto inputTy = cast<ShapedType>(scanOp.getInput().getType());
1449+
auto accumTy = cast<ShapedType>(scanOp.getAccumulator().getType());
1450+
Type elemType = inputTy.getElementType();
1451+
int64_t inputRank = inputTy.getRank();
1452+
int64_t scanDim = scanOp.getDimension();
1453+
1454+
SmallVector<int64_t> inputVecShape =
1455+
vectorSizes.empty() ? llvm::to_vector(inputTy.getShape())
1456+
: llvm::to_vector(vectorSizes);
1457+
1458+
// Accumulator shape = input shape with scan dimension dropped.
1459+
SmallVector<int64_t> accumVecShape = inputVecShape;
1460+
accumVecShape.erase(accumVecShape.begin() + scanDim);
1461+
1462+
auto inputVecTy = VectorType::get(inputVecShape, elemType);
1463+
auto accumVecTy = VectorType::get(accumVecShape, elemType);
1464+
1465+
// Determine if masking is needed (dynamic shapes or vector > tensor).
1466+
bool needsInputMasking = !inputTy.hasStaticShape() ||
1467+
!llvm::equal(inputTy.getShape(), inputVecShape);
1468+
bool needsAccumMasking = !accumTy.hasStaticShape() ||
1469+
!llvm::equal(accumTy.getShape(), accumVecShape);
1470+
1471+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1472+
SmallVector<Value> inputIndices(inputRank, zero);
1473+
SmallVector<Value> accumIndices(accumTy.getRank(), zero);
1474+
1475+
// Read input tensor to vector.
1476+
Value padding = ub::PoisonOp::create(rewriter, loc, elemType);
1477+
Value inputVec = vector::createReadOrMaskedRead(
1478+
rewriter, loc, scanOp.getInput(), inputVecShape, padding,
1479+
/*useInBoundsInsteadOfMasking=*/!needsInputMasking);
1480+
if (needsInputMasking) {
1481+
SmallVector<OpFoldResult> inputDims =
1482+
tensor::getMixedSizes(rewriter, loc, scanOp.getInput());
1483+
auto inputMaskTy = VectorType::get(inputVecShape, rewriter.getI1Type());
1484+
Value inputMask = vector::CreateMaskOp::create(
1485+
rewriter, loc, inputMaskTy,
1486+
getValueOrCreateConstantIndexOp(rewriter, loc, inputDims));
1487+
1488+
// Replace masked-off lanes with identity value.
1489+
Value identity =
1490+
getCombiningIdentityValue(loc, rewriter, *kind, inputVecTy);
1491+
inputVec =
1492+
arith::SelectOp::create(rewriter, loc, inputMask, inputVec, identity);
1493+
}
1494+
1495+
// Read accumulator (initial value) to vector.
1496+
Value accumVec = vector::createReadOrMaskedRead(
1497+
rewriter, loc, scanOp.getAccumulator(), accumVecShape, padding,
1498+
/*useInBoundsInsteadOfMasking=*/!needsAccumMasking);
1499+
if (needsAccumMasking) {
1500+
SmallVector<OpFoldResult> accumDims =
1501+
tensor::getMixedSizes(rewriter, loc, scanOp.getAccumulator());
1502+
auto accumMaskTy = VectorType::get(accumVecShape, rewriter.getI1Type());
1503+
Value accumMask = vector::CreateMaskOp::create(
1504+
rewriter, loc, accumMaskTy,
1505+
getValueOrCreateConstantIndexOp(rewriter, loc, accumDims));
1506+
1507+
Value identity =
1508+
getCombiningIdentityValue(loc, rewriter, *kind, accumVecTy);
1509+
accumVec =
1510+
arith::SelectOp::create(rewriter, loc, accumMask, accumVec, identity);
1511+
}
1512+
1513+
// Create vector.scan.
1514+
auto vectorScanOp =
1515+
vector::ScanOp::create(rewriter, loc, *kind, inputVec, accumVec,
1516+
scanDim, scanOp.getInclusive());
1517+
1518+
// Write results back to tensors.
1519+
Value output = vector::createWriteOrMaskedWrite(
1520+
rewriter, loc, vectorScanOp.getDest(),
1521+
scanOp.getOutput(), inputIndices,
1522+
/*useInBoundsInsteadOfMasking=*/!needsInputMasking)
1523+
->getResult(0);
1524+
1525+
Value accum = vector::createWriteOrMaskedWrite(
1526+
rewriter, loc, vectorScanOp.getAccumulatedValue(),
1527+
scanOp.getAccumulator(), accumIndices,
1528+
/*useInBoundsInsteadOfMasking=*/!needsAccumMasking)
1529+
->getResult(0);
1530+
1531+
return SmallVector<Value>{output, accum};
1532+
}
1533+
};
1534+
13451535
} // namespace
13461536

13471537
void registerVectorizableOpInterfaceExternalModels(DialectRegistry &registry) {
@@ -1355,6 +1545,7 @@ void registerVectorizableOpInterfaceExternalModels(DialectRegistry &registry) {
13551545
*ctx);
13561546
IREE::LinalgExt::Im2colOp::attachInterface<Im2colOpVectorizationModel>(
13571547
*ctx);
1548+
IREE::LinalgExt::ScanOp::attachInterface<ScanOpVectorizationModel>(*ctx);
13581549
});
13591550
registry.addExtension(+[](MLIRContext *ctx,
13601551
IREE::VectorExt::IREEVectorExtDialect *dialect) {

0 commit comments

Comments
 (0)