Skip to content

Commit 936478a

Browse files
Use row safe denominator for masked attention
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent 25923fa commit 936478a

6 files changed

Lines changed: 86 additions & 30 deletions

File tree

compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,14 @@ FailureOr<SmallVector<Value>> AttentionOp::decomposeOperation(OpBuilder &b) {
476476

477477
Value fullyMaskedRows;
478478
if (mask != nullptr) {
479-
fullyMaskedRows =
480-
createFullyMaskedRowsFromScores(b, loc, sMap, maxMap, rowRedSize, s);
479+
Type maskElementType = getElementTypeOrSelf(mask.value().getType());
480+
if (isa<IntegerType>(maskElementType)) {
481+
fullyMaskedRows = createFullyMaskedRowsFromMask(
482+
b, loc, *getMaskMap(), maxMap, rowRedSize, mask.value());
483+
} else {
484+
fullyMaskedRows =
485+
createFullyMaskedRowsFromScores(b, loc, sMap, maxMap, rowRedSize, s);
486+
}
481487
}
482488

483489
// max = rowMax(S)

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ func.func @attention_f16_masked(%query: tensor<192x1024x64xf16>,
185185
// CHECK: linalg.generic
186186
// CHECK: arith.addf
187187
// CHECK: linalg.yield
188-
// masked_rows = rowAll(isneginf(S))
188+
// masked_rows = rowAll(!mask)
189189
// CHECK: linalg.generic
190-
// CHECK: arith.cmpf oeq
190+
// CHECK: arith.xori
191191
// CHECK: arith.andi
192192
// CHECK: linalg.yield
193193
// max = rowMax(S)

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ struct ConvertAttentionToOnlineAttentionPass final
3030
void runOnOperation() override;
3131
};
3232

33+
/// Builds an elementwise `linalg.generic` over `sum` that yields
/// `maximumf(sum, 1.0)` for every element, and returns that op's result.
///
/// `rowMap` serves as the indexing map of both the input and the freshly
/// created output (after unused dims are compressed away). `rowSizes` gives
/// the sizes of the output tensor, whose element type is taken from `sum`.
static Value createSafeDenominator(OpBuilder &builder, Location loc,
                                   AffineMap rowMap,
                                   ArrayRef<OpFoldResult> rowSizes, Value sum) {
  // The input and the init operand share the same row map; compress the pair
  // together so the generic only iterates over dims the map actually uses.
  SmallVector<AffineMap> indexingMaps =
      compressUnusedDims(SmallVector<AffineMap>{rowMap, rowMap});
  SmallVector<utils::IteratorType> allParallel(
      indexingMaps.front().getNumDims(), utils::IteratorType::parallel);

  Type elementType = getElementTypeOrSelf(sum.getType());
  Value init = tensor::EmptyOp::create(builder, loc, rowSizes, elementType);

  // Region body: clamp the incoming element from below at 1.0 and yield it.
  auto clampToOne = [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
    Value element = args[0];
    Value one = arith::ConstantOp::create(
        b, nestedLoc, b.getFloatAttr(element.getType(), 1.0));
    Value clamped = arith::MaximumFOp::create(b, nestedLoc, element, one);
    linalg::YieldOp::create(b, nestedLoc, clamped);
  };

  auto clampOp =
      linalg::GenericOp::create(builder, loc, init.getType(), sum, init,
                                indexingMaps, allParallel, clampToOne);
  return clampOp.getResult(0);
}
56+
3357
} // namespace
3458

3559
void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
@@ -124,39 +148,37 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
124148
Value x = onlineAttn.getResult(0);
125149
Value sum = onlineAttn.getResult(2);
126150
bool hasMask = static_cast<bool>(mask);
151+
Value denominator = sum;
152+
if (hasMask) {
153+
denominator = createSafeDenominator(rewriter, loc, sumMap, rowRedSize, sum);
154+
ops.push_back(denominator.getDefiningOp());
155+
}
127156

128157
// Finalize online attention. With a mask, fully-masked rows can have
129-
// `sum == 0` and `x == 0`; guard that case in the finalization loop to
130-
// produce 0 instead of NaN. Keep this fused with finalization because the
131-
// online-attention lowering path expects a single finalization consumer.
158+
// `sum == 0` and `x == 0`. Rows with at least one finite score have
159+
// `sum >= 1`, so clamp the row denominator to 1 before the finalization loop.
160+
// This preserves the normal rows and produces 0 for fully-masked rows without
161+
// adding a per-output guard.
132162

133163
// Compress the indexing maps.
134-
SmallVector<AffineMap> compressedMaps =
135-
compressUnusedDims(SmallVector<AffineMap>{sumMap, accMap, accMap});
164+
SmallVector<Value> finalizeInputs = {denominator, x};
165+
SmallVector<AffineMap> finalizeMaps = {sumMap, accMap, accMap};
166+
167+
SmallVector<AffineMap> compressedMaps = compressUnusedDims(finalizeMaps);
136168

137169
SmallVector<utils::IteratorType> iteratorTypes(compressedMaps[0].getNumDims(),
138170
utils::IteratorType::parallel);
139171

140172
auto genericOp = linalg::GenericOp::create(
141-
rewriter, loc, attnOp.getOutput().getType(), ValueRange{sum, x},
173+
rewriter, loc, attnOp.getOutput().getType(), finalizeInputs,
142174
attnOp.getOutput(), compressedMaps, iteratorTypes,
143175
[&](OpBuilder &b, Location loc, ValueRange args) {
144-
Value result;
145-
if (hasMask) {
146-
result = arith::DivFOp::create(b, loc, args[1], args[0]);
147-
Value zero =
148-
arith::ConstantOp::create(b, loc, b.getFloatAttr(f32Type, 0.0));
149-
Value isZero = arith::CmpFOp::create(
150-
b, loc, arith::CmpFPredicate::OEQ, args[0], zero);
151-
result = arith::SelectOp::create(b, loc, isZero, zero, result);
152-
} else {
153-
Value one = arith::ConstantOp::create(
154-
b, loc, b.getFloatAttr(args[0].getType(), 1.0));
155-
Value reciprocal = arith::DivFOp::create(b, loc, one, args[0]);
156-
// Both sum and x are in fp32, as created earlier, so we only need to
157-
// cast after the mul.
158-
result = arith::MulFOp::create(b, loc, reciprocal, args[1]);
159-
}
176+
Value one = arith::ConstantOp::create(
177+
b, loc, b.getFloatAttr(args[0].getType(), 1.0));
178+
Value reciprocal = arith::DivFOp::create(b, loc, one, args[0]);
179+
// Both sum and x are in fp32, as created earlier, so we only need to
180+
// cast after the mul.
181+
Value result = arith::MulFOp::create(b, loc, reciprocal, args[1]);
160182
// Cast result to the required type by attention output.
161183
result = convertScalarToDtype(b, loc, result, args[2].getType(),
162184
/*isUnsignedCast=*/false);

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_online_attention.mlir

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,17 @@ func.func @masked_attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x
6666

6767
// CHECK-LABEL: func.func @masked_attention
6868
// CHECK-SAME: %[[MASK:.+]]: tensor<2x10x4096x4096xi1>
69-
// Masked: keep the fully-masked row guard fused with finalization.
69+
// Masked: clamp the row denominator once and keep finalization in the same
70+
// reciprocal-multiply shape as the unmasked path.
7071
// CHECK: %[[OUT:.+]]:3 = iree_linalg_ext.online_attention
72+
// CHECK: %[[DENOM:.+]] = linalg.generic
73+
// CHECK-SAME: ins(%[[OUT]]#2
74+
// CHECK: arith.maximumf
75+
// CHECK: linalg.yield
7176
// CHECK: linalg.generic
72-
// CHECK-SAME: ins(%[[OUT]]#2, %[[OUT]]#0
77+
// CHECK-SAME: ins(%[[DENOM]], %[[OUT]]#0
7378
// CHECK: arith.divf
74-
// CHECK: arith.cmpf oeq
75-
// CHECK: arith.select
79+
// CHECK: arith.mulf
80+
// CHECK-NOT: arith.select
7681
// CHECK: arith.truncf
7782
// CHECK: linalg.yield

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
139139
});
140140
}
141141

142+
/// Forwards to `createFullyMaskedRows` with a per-element callback that maps
/// each element of the integer `mask` to its logical negation, i.e. the
/// element-level test is "this mask element is false".
///
/// Elements whose type is wider than one bit are truncated to i1 before being
/// negated.
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
                                    AffineMap maskMap, AffineMap rowMap,
                                    ArrayRef<OpFoldResult> rowSizes,
                                    Value mask) {
  auto negateMaskElement = [&](OpBuilder &b, Location nestedLoc,
                               Value element) {
    const bool isSingleBit = element.getType().getIntOrFloatBitWidth() == 1;
    if (!isSingleBit) {
      element = arith::TruncIOp::create(b, nestedLoc, b.getI1Type(), element);
    }
    // `x xor true` is the i1 logical NOT of `x`.
    Value allOnes = arith::ConstantOp::create(b, nestedLoc,
                                              b.getBoolAttr(/*value=*/true));
    return arith::XOrIOp::create(b, nestedLoc, element, allOnes);
  };

  return createFullyMaskedRows(builder, loc, maskMap, rowMap, rowSizes, mask,
                               negateMaskElement);
}
157+
142158
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
143159
AffineMap rowMap, Value value,
144160
Value fullyMaskedRows) {

compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ Value createFullyMaskedRowsFromScores(OpBuilder &builder, Location loc,
5050
ArrayRef<OpFoldResult> rowSizes,
5151
Value scores);
5252

53+
/// Compute a row predicate for safe masked-softmax finalization by checking
54+
/// whether every integer mask element in a softmax row is false.
///
/// `mask` is the integer-typed mask value; `maskMap`, `rowMap` and `rowSizes`
/// are forwarded unchanged to the shared `createFullyMaskedRows` helper.
/// NOTE(review): the definition in Utils.cpp truncates mask elements wider
/// than one bit to i1 before testing them, so only the lowest bit of such an
/// element decides whether it counts as "false" -- confirm this matches how
/// the mask is applied to the scores.
55+
Value createFullyMaskedRowsFromMask(OpBuilder &builder, Location loc,
56+
AffineMap maskMap, AffineMap rowMap,
57+
ArrayRef<OpFoldResult> rowSizes,
58+
Value mask);
59+
5360
/// Zero every element in rows whose row predicate is true.
5461
Value zeroFullyMaskedRows(OpBuilder &builder, Location loc, AffineMap valueMap,
5562
AffineMap rowMap, Value value, Value fullyMaskedRows);

0 commit comments

Comments
 (0)