88#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
99#include " iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
1010#include " iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
11- #include " iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1211#include " mlir/Dialect/Arith/IR/Arith.h"
1312#include " mlir/Dialect/Arith/Utils/Utils.h"
1413#include " mlir/Dialect/Tensor/IR/Tensor.h"
@@ -124,40 +123,32 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
124123
125124 Value x = onlineAttn.getResult (0 );
126125 Value sum = onlineAttn.getResult (2 );
127- Value fullyMaskedRows;
128- if (mask) {
129- fullyMaskedRows = createFullyMaskedRowsFromMask (
130- rewriter, loc, *attnOp.getMaskMap (), sumMap, rowRedSize, *mask);
131- }
126+ bool hasMask = static_cast <bool >(mask);
132127
133128 // Finalize online attention. With a mask, fully-masked rows can have
134- // `sum == 0` and `x == 0`; zero those rows after normalization to produce 0
135- // instead of NaN.
129+ // `sum == 0` and `x == 0`; guard that case in the finalization loop to
130+ // produce 0 instead of NaN. Keep this fused with finalization because the
131+ // online-attention lowering path expects a single finalization consumer.
136132
137133 // Compress the indexing maps.
138- SmallVector<Value> inputs = {sum, x};
139- SmallVector<AffineMap> indexingMapsForFinalization = {sumMap, accMap};
140- if (fullyMaskedRows) {
141- inputs.push_back (fullyMaskedRows);
142- indexingMapsForFinalization.push_back (sumMap);
143- }
144- indexingMapsForFinalization.push_back (accMap);
145134 SmallVector<AffineMap> compressedMaps =
146- compressUnusedDims (indexingMapsForFinalization );
135+ compressUnusedDims (SmallVector<AffineMap>{sumMap, accMap, accMap} );
147136
148137 SmallVector<utils::IteratorType> iteratorTypes (compressedMaps[0 ].getNumDims (),
149138 utils::IteratorType::parallel);
150139
151140 auto genericOp = linalg::GenericOp::create (
152- rewriter, loc, attnOp.getOutput ().getType (), inputs, attnOp. getOutput () ,
153- compressedMaps, iteratorTypes,
141+ rewriter, loc, attnOp.getOutput ().getType (), ValueRange{sum, x} ,
142+ attnOp. getOutput (), compressedMaps, iteratorTypes,
154143 [&](OpBuilder &b, Location loc, ValueRange args) {
155144 Value result;
156- if (fullyMaskedRows ) {
145+ if (hasMask ) {
157146 result = arith::DivFOp::create (b, loc, args[1 ], args[0 ]);
158- Value zero = arith::ConstantOp::create (
159- b, loc, b.getFloatAttr (args[1 ].getType (), 0.0 ));
160- result = arith::SelectOp::create (b, loc, args[2 ], zero, result);
147+ Value zero =
148+ arith::ConstantOp::create (b, loc, b.getFloatAttr (f32Type, 0.0 ));
149+ Value isZero = arith::CmpFOp::create (
150+ b, loc, arith::CmpFPredicate::OEQ, args[0 ], zero);
151+ result = arith::SelectOp::create (b, loc, isZero, zero, result);
161152 } else {
162153 Value one = arith::ConstantOp::create (
163154 b, loc, b.getFloatAttr (args[0 ].getType (), 1.0 ));
@@ -167,7 +158,7 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
167158 result = arith::MulFOp::create (b, loc, reciprocal, args[1 ]);
168159 }
169160 // Cast result to the required type by attention output.
170- result = convertScalarToDtype (b, loc, result, args. back () .getType (),
161+ result = convertScalarToDtype (b, loc, result, args[ 2 ] .getType (),
171162 /* isUnsignedCast=*/ false );
172163 linalg::YieldOp::create (b, loc, result);
173164 });
0 commit comments