@@ -30,6 +30,30 @@ struct ConvertAttentionToOnlineAttentionPass final
3030 void runOnOperation () override ;
3131};
3232
33+ static Value createSafeDenominator (OpBuilder &builder, Location loc,
34+ AffineMap rowMap,
35+ ArrayRef<OpFoldResult> rowSizes, Value sum) {
36+ SmallVector<AffineMap> compressedMaps =
37+ compressUnusedDims (SmallVector<AffineMap>{rowMap, rowMap});
38+ AffineMap inputMap = compressedMaps[0 ];
39+ AffineMap outputMap = compressedMaps[1 ];
40+
41+ Value output = tensor::EmptyOp::create (builder, loc, rowSizes,
42+ getElementTypeOrSelf (sum.getType ()));
43+ SmallVector<utils::IteratorType> iteratorTypes (inputMap.getNumDims (),
44+ utils::IteratorType::parallel);
45+ auto genericOp = linalg::GenericOp::create (
46+ builder, loc, output.getType (), sum, output,
47+ SmallVector<AffineMap>{inputMap, outputMap}, iteratorTypes,
48+ [&](OpBuilder &b, Location loc, ValueRange args) {
49+ Value one = arith::ConstantOp::create (
50+ b, loc, b.getFloatAttr (args[0 ].getType (), 1.0 ));
51+ Value denominator = arith::MaximumFOp::create (b, loc, args[0 ], one);
52+ linalg::YieldOp::create (b, loc, denominator);
53+ });
54+ return genericOp.getResult (0 );
55+ }
56+
3357} // namespace
3458
3559void convertToOnlineAttention (IREE::LinalgExt::AttentionOp attnOp,
@@ -124,39 +148,37 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp,
124148 Value x = onlineAttn.getResult (0 );
125149 Value sum = onlineAttn.getResult (2 );
126150 bool hasMask = static_cast <bool >(mask);
151+ Value denominator = sum;
152+ if (hasMask) {
153+ denominator = createSafeDenominator (rewriter, loc, sumMap, rowRedSize, sum);
154+ ops.push_back (denominator.getDefiningOp ());
155+ }
127156
128157 // Finalize online attention. With a mask, fully-masked rows can have
129- // `sum == 0` and `x == 0`; guard that case in the finalization loop to
130- // produce 0 instead of NaN. Keep this fused with finalization because the
131- // online-attention lowering path expects a single finalization consumer.
158+ // `sum == 0` and `x == 0`. Rows with at least one finite score have
159+ // `sum >= 1`, so clamp the row denominator to 1 before the finalization loop.
160+ // This preserves the normal rows and produces 0 for fully-masked rows without
161+ // adding a per-output guard.
132162
133163 // Compress the indexing maps.
134- SmallVector<AffineMap> compressedMaps =
135- compressUnusedDims (SmallVector<AffineMap>{sumMap, accMap, accMap});
164+ SmallVector<Value> finalizeInputs = {denominator, x};
165+ SmallVector<AffineMap> finalizeMaps = {sumMap, accMap, accMap};
166+
167+ SmallVector<AffineMap> compressedMaps = compressUnusedDims (finalizeMaps);
136168
137169 SmallVector<utils::IteratorType> iteratorTypes (compressedMaps[0 ].getNumDims (),
138170 utils::IteratorType::parallel);
139171
140172 auto genericOp = linalg::GenericOp::create (
141- rewriter, loc, attnOp.getOutput ().getType (), ValueRange{sum, x} ,
173+ rewriter, loc, attnOp.getOutput ().getType (), finalizeInputs ,
142174 attnOp.getOutput (), compressedMaps, iteratorTypes,
143175 [&](OpBuilder &b, Location loc, ValueRange args) {
144- Value result;
145- if (hasMask) {
146- result = arith::DivFOp::create (b, loc, args[1 ], args[0 ]);
147- Value zero =
148- arith::ConstantOp::create (b, loc, b.getFloatAttr (f32Type, 0.0 ));
149- Value isZero = arith::CmpFOp::create (
150- b, loc, arith::CmpFPredicate::OEQ, args[0 ], zero);
151- result = arith::SelectOp::create (b, loc, isZero, zero, result);
152- } else {
153- Value one = arith::ConstantOp::create (
154- b, loc, b.getFloatAttr (args[0 ].getType (), 1.0 ));
155- Value reciprocal = arith::DivFOp::create (b, loc, one, args[0 ]);
156- // Both sum and x are in fp32, as created earlier, so we only need to
157- // cast after the mul.
158- result = arith::MulFOp::create (b, loc, reciprocal, args[1 ]);
159- }
176+ Value one = arith::ConstantOp::create (
177+ b, loc, b.getFloatAttr (args[0 ].getType (), 1.0 ));
178+ Value reciprocal = arith::DivFOp::create (b, loc, one, args[0 ]);
179+ // Both sum and x are in fp32, as created earlier, so we only need to
180+ // cast after the mul.
181+ Value result = arith::MulFOp::create (b, loc, reciprocal, args[1 ]);
160182 // Cast result to the required type by attention output.
161183 result = convertScalarToDtype (b, loc, result, args[2 ].getType (),
162184 /* isUnsignedCast=*/ false );
0 commit comments