@@ -860,6 +860,29 @@ static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand) {
860860 return failure ();
861861}
862862
863+ // Helper to check if a value is a splat constant with all -1 values.
864+ // This handles signless integer types correctly by checking the raw bits.
865+ static bool isConstantMinusOne (Value v) {
866+ Attribute attr;
867+ if (auto cst = v.getDefiningOp <tosa::ConstOp>())
868+ attr = cst.getValuesAttr ();
869+ else if (auto cst = v.getDefiningOp <arith::ConstantOp>())
870+ attr = cst.getValue ();
871+ else
872+ return false ;
873+
874+ auto splatAttr = dyn_cast<SplatElementsAttr>(attr);
875+ if (!splatAttr)
876+ return false ;
877+ auto elemTy = splatAttr.getElementType ();
878+ if (!elemTy.isIntOrIndex ())
879+ return false ;
880+
881+ // Get the splat value as APInt and check if all bits are set (i.e., -1)
882+ APInt val = splatAttr.getSplatValue <APInt>();
883+ return val.isAllOnes ();
884+ }
885+
863886class MatMulConverter final : public OpConversionPattern<tosa::MatMulOp> {
864887public:
865888 using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -2055,6 +2078,26 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
20552078 assert (succeeded (maybeCurrentSeqLen) && " Must have non-reshape op" );
20562079 Value currentSeqLen = maybeCurrentSeqLen.value ();
20572080
2081+ // Handle pattern: greater(iota, seqLen + (-1)) which is equivalent to
2082+ // iota >= seqLen. This is a common pattern in KV-cache masking where we
2083+ // want to mask positions >= currentSeqLen.
2084+ if (auto addOp = currentSeqLen.getDefiningOp <tosa::AddOp>()) {
2085+ Value input1 = addOp.getInput1 ();
2086+ Value input2 = addOp.getInput2 ();
2087+ // Check if one operand is constant -1
2088+ if (isConstantMinusOne (input2)) {
2089+ // Use input1 as the currentSeqLen candidate
2090+ auto maybeUnwrapped = getValueSkipping (input1, expandAndCollapse);
2091+ if (succeeded (maybeUnwrapped))
2092+ currentSeqLen = maybeUnwrapped.value ();
2093+ } else if (isConstantMinusOne (input1)) {
2094+ // Use input2 as the currentSeqLen candidate
2095+ auto maybeUnwrapped = getValueSkipping (input2, expandAndCollapse);
2096+ if (succeeded (maybeUnwrapped))
2097+ currentSeqLen = maybeUnwrapped.value ();
2098+ }
2099+ }
2100+
20582101 // Verify currentSeqLen is i32 and traces back to a block argument
20592102 if (!isI32BlockArgument (currentSeqLen, seqLenSkip))
20602103 return failure ();
@@ -3385,15 +3428,30 @@ static FailureOr<Value> matchDerefInputPattern(Value derefInput) {
33853428 Value lhsSource = getPreBroadcastSource (lhs);
33863429 Value rhsSource = getPreBroadcastSource (rhs);
33873430
3388- auto lhsType = cast<ShapedType>(lhsSource.getType ());
3389- auto rhsType = cast<ShapedType>(rhsSource.getType ());
3431+ // Helper to trace back through view ops to find the original 3D tensor
3432+ auto traceBackThroughViewOps = [](Value v) -> Value {
3433+ while (Operation *defOp = v.getDefiningOp ()) {
3434+ if (!viewOps.contains (defOp->getName ().getStringRef ()))
3435+ break ;
3436+ // All view ops in our set have a single input
3437+ v = defOp->getOperand (0 );
3438+ }
3439+ return v;
3440+ };
3441+
3442+ // Trace back through view ops to find the original 3D pointer tensor
3443+ Value lhsOriginal = traceBackThroughViewOps (lhsSource);
3444+ Value rhsOriginal = traceBackThroughViewOps (rhsSource);
3445+
3446+ auto lhsOriginalType = cast<ShapedType>(lhsOriginal.getType ());
3447+ auto rhsOriginalType = cast<ShapedType>(rhsOriginal.getType ());
33903448
33913449 // Check which one has last dimension = 1 (pointers)
33923450 // The pointers tensor should have shape [batch, blocks, 1]
3393- if (lhsType .getRank () == 3 && lhsType .getShape ()[2 ] == 1 )
3394- return lhsSource ;
3395- if (rhsType .getRank () == 3 && rhsType .getShape ()[2 ] == 1 )
3396- return rhsSource ;
3451+ if (lhsOriginalType .getRank () == 3 && lhsOriginalType .getShape ()[2 ] == 1 )
3452+ return lhsOriginal ;
3453+ if (rhsOriginalType .getRank () == 3 && rhsOriginalType .getShape ()[2 ] == 1 )
3454+ return rhsOriginal ;
33973455
33983456 return failure ();
33993457}
0 commit comments