Skip to content

Commit d0f20c8

Browse files
Keep masked attention finalization reciprocal form
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent 25923fa commit d0f20c8

6 files changed

Lines changed: 49 additions & 19 deletions

File tree

compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,14 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
476476

477477
Value fullyMaskedRows;
478478
if (mask != nullptr) {
479-
fullyMaskedRows =
480-
createFullyMaskedRowsFromScores(b, loc, sMap, maxMap, rowRedSize, s);
479+
Type maskElementType = getElementTypeOrSelf(mask.value().getType());
480+
if (isa<IntegerType>(maskElementType)) {
481+
fullyMaskedRows = createFullyMaskedRowsFromMask(
482+
b, loc, *getMaskMap(), maxMap, rowRedSize, mask.value());
483+
} else {
484+
fullyMaskedRows =
485+
createFullyMaskedRowsFromScores(b, loc, sMap, maxMap, rowRedSize, s);
486+
}
481487
}
482488

483489
// max = rowMax(S)

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ func.func @attention_f16_masked(%query: tensor<192x1024x64xf16>,
185185
// CHECK: linalg.generic
186186
// CHECK: arith.addf
187187
// CHECK: linalg.yield
188-
// masked_rows = rowAll(isneginf(S))
188+
// masked_rows = rowAll(!mask)
189189
// CHECK: linalg.generic
190-
// CHECK: arith.cmpf oeq
190+
// CHECK: arith.xori
191191
// CHECK: arith.andi
192192
// CHECK: linalg.yield
193193
// max = rowMax(S)

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,35 +127,34 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
127127

128128
// Finalize online attention. With a mask, fully-masked rows can have
129129
// `sum == 0` and `x == 0`; guard that case in the finalization loop to
130-
// produce 0 instead of NaN. Keep this fused with finalization because the
131-
// online-attention lowering path expects a single finalization consumer.
130+
// produce 0 instead of NaN. Keep the normal reciprocal-multiply shape so the
131+
// common path stays as close as possible to unmasked finalization.
132132

133133
// Compress the indexing maps.
134-
SmallVector<AffineMap> compressedMaps =
135-
compressUnusedDims(SmallVector<AffineMap>{sumMap, accMap, accMap});
134+
SmallVector<Value> finalizeInputs = {sum, x};
135+
SmallVector<AffineMap> finalizeMaps = {sumMap, accMap, accMap};
136+
137+
SmallVector<AffineMap> compressedMaps = compressUnusedDims(finalizeMaps);
136138

137139
SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
138140
utils::IteratorType::parallel);
139141

140142
auto genericOp = linalg::GenericOp::create(
141-
rewriter, loc, attnOp.getOutput().getType(), ValueRange{sum, x},
143+
rewriter, loc, attnOp.getOutput().getType(), finalizeInputs,
142144
attnOp.getOutput(), compressedMaps, iteratorTypes,
143145
[&](OpBuilder &b, Location loc, ValueRange args) {
144-
Value result;
146+
Value one = arith::ConstantOp::create(
147+
b, loc, b.getFloatAttr(args[0].getType(), 1.0));
148+
Value reciprocal = arith::DivFOp::create(b, loc, one, args[0]);
149+
// Both sum and x are in fp32, as created earlier, so we only need to
150+
// cast after the mul.
151+
Value result = arith::MulFOp::create(b, loc, reciprocal, args[1]);
145152
if (hasMask) {
146-
result = arith::DivFOp::create(b, loc, args[1], args[0]);
147153
Value zero =
148154
arith::ConstantOp::create(b, loc, b.getFloatAttr(f32Type, 0.0));
149155
Value isZero = arith::CmpFOp::create(
150156
b, loc, arith::CmpFPredicate::OEQ, args[0], zero);
151157
result = arith::SelectOp::create(b, loc, isZero, zero, result);
152-
} else {
153-
Value one = arith::ConstantOp::create(
154-
b, loc, b.getFloatAttr(args[0].getType(), 1.0));
155-
Value reciprocal = arith::DivFOp::create(b, loc, one, args[0]);
156-
// Both sum and x are in fp32, as created earlier, so we only need to
157-
// cast after the mul.
158-
result = arith::MulFOp::create(b, loc, reciprocal, args[1]);
159158
}
160159
// Cast result to the required type by attention output.
161160
result = convertScalarToDtype(b, loc, result, args[2].getType(),

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,13 @@ func.func @masked_attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x
6666

6767
// CHECK-LABEL: func.func @masked_attention
6868
// CHECK-SAME: %[[MASK:.+]]: tensor<2x10x4096x4096xi1>
69-
// Masked: keep the fully-masked row guard fused with finalization.
69+
// Masked: keep the same reciprocal-multiply finalization shape and guard the
70+
// fully-masked row case.
7071
// CHECK: %[[OUT:.+]]:3 = iree_linalg_ext.online_attention
7172
// CHECK: linalg.generic
7273
// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0
7374
// CHECK: arith.divf
75+
// CHECK: arith.mulf
7476
// CHECK: arith.cmpf oeq
7577
// CHECK: arith.select
7678
// CHECK: arith.truncf

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
139139
});
140140
}
141141

142+
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
143+
AffineMap maskMap, AffineMap rowMap,
144+
ArrayRef<OpFoldResult> rowSizes,
145+
Value mask) {
146+
return createFullyMaskedRows(
147+
builder, loc, maskMap, rowMap, rowSizes, mask,
148+
[&](OpBuilder &b, Location loc, Value maskValue) {
149+
if (maskValue.getType().getIntOrFloatBitWidth() != 1) {
150+
maskValue = arith::TruncIOp::create(b, loc, b.getI1Type(), maskValue);
151+
}
152+
Value trueValue =
153+
arith::ConstantOp::create(b, loc, b.getBoolAttr(/*value=*/true));
154+
return arith::XOrIOp::create(b, loc, maskValue, trueValue);
155+
});
156+
}
157+
142158
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
143159
AffineMap rowMap, Value value,
144160
Value fullyMaskedRows) {

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
5050
ArrayRef<OpFoldResult> rowSizes,
5151
Value scores);
5252

53+
/// Compute a row predicate for safe masked-softmax finalization by checking
54+
/// whether every integer mask element in a softmax row is false.
55+
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
56+
AffineMap maskMap, AffineMap rowMap,
57+
ArrayRef<OpFoldResult> rowSizes,
58+
Value mask);
59+
5360
/// Zero every element in rows whose row predicate is true.
5461
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
5562
AffineMap rowMap, Value value, Value fullyMaskedRows);

0 commit comments

Comments
 (0)