Skip to content

Commit 1e8d55e

Browse files
Keep masked online attention finalization fused
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent d41c122 commit 1e8d55e

4 files changed

Lines changed: 17 additions & 64 deletions

File tree

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
99
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
1010
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
11-
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1211
#include "mlir/Dialect/Arith/IR/Arith.h"
1312
#include "mlir/Dialect/Arith/Utils/Utils.h"
1413
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -124,40 +123,32 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
124123

125124
Value x = onlineAttn.getResult(0);
126125
Value sum = onlineAttn.getResult(2);
127-
Value fullyMaskedRows;
128-
if (mask) {
129-
fullyMaskedRows = createFullyMaskedRowsFromMask(
130-
rewriter, loc, *attnOp.getMaskMap(), sumMap, rowRedSize, *mask);
131-
}
126+
bool hasMask = static_cast<bool>(mask);
132127

133128
// Finalize online attention. With a mask, fully-masked rows can have
134-
// `sum == 0` and `x == 0`; zero those rows after normalization to produce 0
135-
// instead of NaN.
129+
// `sum == 0` and `x == 0`; guard that case in the finalization loop to
130+
// produce 0 instead of NaN. Keep this fused with finalization because the
131+
// online-attention lowering path expects a single finalization consumer.
136132

137133
// Compress the indexing maps.
138-
SmallVector<Value> inputs = {sum, x};
139-
SmallVector<AffineMap> indexingMapsForFinalization = {sumMap, accMap};
140-
if (fullyMaskedRows) {
141-
inputs.push_back(fullyMaskedRows);
142-
indexingMapsForFinalization.push_back(sumMap);
143-
}
144-
indexingMapsForFinalization.push_back(accMap);
145134
SmallVector<AffineMap> compressedMaps =
146-
compressUnusedDims(indexingMapsForFinalization);
135+
compressUnusedDims(SmallVector<AffineMap>{sumMap, accMap, accMap});
147136

148137
SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
149138
utils::IteratorType::parallel);
150139

151140
auto genericOp = linalg::GenericOp::create(
152-
rewriter, loc, attnOp.getOutput().getType(), inputs, attnOp.getOutput(),
153-
compressedMaps, iteratorTypes,
141+
rewriter, loc, attnOp.getOutput().getType(), ValueRange{sum, x},
142+
attnOp.getOutput(), compressedMaps, iteratorTypes,
154143
[&](OpBuilder &b, Location loc, ValueRange args) {
155144
Value result;
156-
if (fullyMaskedRows) {
145+
if (hasMask) {
157146
result = arith::DivFOp::create(b, loc, args[1], args[0]);
158-
Value zero = arith::ConstantOp::create(
159-
b, loc, b.getFloatAttr(args[1].getType(), 0.0));
160-
result = arith::SelectOp::create(b, loc, args[2], zero, result);
147+
Value zero =
148+
arith::ConstantOp::create(b, loc, b.getFloatAttr(f32Type, 0.0));
149+
Value isZero = arith::CmpFOp::create(
150+
b, loc, arith::CmpFPredicate::OEQ, args[0], zero);
151+
result = arith::SelectOp::create(b, loc, isZero, zero, result);
161152
} else {
162153
Value one = arith::ConstantOp::create(
163154
b, loc, b.getFloatAttr(args[0].getType(), 1.0));
@@ -167,7 +158,7 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
167158
result = arith::MulFOp::create(b, loc, reciprocal, args[1]);
168159
}
169160
// Cast result to the required type by attention output.
170-
result = convertScalarToDtype(b, loc, result, args.back().getType(),
161+
result = convertScalarToDtype(b, loc, result, args[2].getType(),
171162
/*isUnsignedCast=*/false);
172163
linalg::YieldOp::create(b, loc, result);
173164
});

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,12 @@ func.func @masked_attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x
6666

6767
// CHECK-LABEL: func.func @masked_attention
6868
// CHECK-SAME: %[[MASK:.+]]: tensor<2x10x4096x4096xi1>
69-
// Masked: compute the fully-masked row predicate from the mask and use it to
70-
// zero those rows after normalization.
69+
// Masked: keep the fully-masked row guard fused with finalization.
7170
// CHECK: %[[OUT:.+]]:3 = iree_linalg_ext.online_attention
72-
// CHECK: %[[FULLY_MASKED:.+]] = linalg.generic
73-
// CHECK-SAME: ins(%[[MASK]]
74-
// CHECK: arith.xori
75-
// CHECK: arith.andi
76-
// CHECK: linalg.yield
7771
// CHECK: linalg.generic
78-
// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0, %[[FULLY_MASKED]]
72+
// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0
7973
// CHECK: arith.divf
74+
// CHECK: arith.cmpf oeq
8075
// CHECK: arith.select
8176
// CHECK: arith.truncf
8277
// CHECK: linalg.yield

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,31 +139,6 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
139139
});
140140
}
141141

142-
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
143-
AffineMap maskMap, AffineMap rowMap,
144-
ArrayRef<OpFoldResult> rowSizes,
145-
Value mask) {
146-
return createFullyMaskedRows(
147-
builder, loc, maskMap, rowMap, rowSizes, mask,
148-
[](OpBuilder &b, Location loc, Value maskValue) -> Value {
149-
Type maskType = maskValue.getType();
150-
if (maskType.isInteger()) {
151-
if (maskType.getIntOrFloatBitWidth() != 1) {
152-
maskValue =
153-
arith::TruncIOp::create(b, loc, b.getI1Type(), maskValue);
154-
}
155-
Value trueValue =
156-
arith::ConstantOp::create(b, loc, b.getBoolAttr(/*value=*/true));
157-
return arith::XOrIOp::create(b, loc, maskValue, trueValue);
158-
}
159-
Value negInf = arith::ConstantOp::create(
160-
b, loc,
161-
b.getFloatAttr(maskType, -std::numeric_limits<double>::infinity()));
162-
return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OEQ,
163-
maskValue, negInf);
164-
});
165-
}
166-
167142
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
168143
AffineMap rowMap, Value value,
169144
Value fullyMaskedRows) {

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,6 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
5050
ArrayRef<OpFoldResult> rowSizes,
5151
Value scores);
5252

53-
/// Compute a row predicate for safe masked-softmax finalization directly from
54-
/// an attention mask. Integer masks are normalized to i1 and use `false`/0 as
55-
/// masked; floating-point masks use `-inf` as masked.
56-
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
57-
AffineMap maskMap, AffineMap rowMap,
58-
ArrayRef<OpFoldResult> rowSizes,
59-
Value mask);
60-
6153
/// Zero every element in rows whose row predicate is true.
6254
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
6355
AffineMap rowMap, Value value, Value fullyMaskedRows);

0 commit comments

Comments
 (0)