Skip to content

Commit 07583ae

Browse files
committed
Handle new deref patterns from MIGraphX
1 parent 49370de commit 07583ae

1 file changed

Lines changed: 64 additions & 6 deletions

File tree

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,29 @@ static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand) {
860860
return failure();
861861
}
862862

863+
// Helper to check if a value is a splat constant with all -1 values.
864+
// This handles signless integer types correctly by checking the raw bits.
865+
static bool isConstantMinusOne(Value v) {
866+
Attribute attr;
867+
if (auto cst = v.getDefiningOp<tosa::ConstOp>())
868+
attr = cst.getValuesAttr();
869+
else if (auto cst = v.getDefiningOp<arith::ConstantOp>())
870+
attr = cst.getValue();
871+
else
872+
return false;
873+
874+
auto splatAttr = dyn_cast<SplatElementsAttr>(attr);
875+
if (!splatAttr)
876+
return false;
877+
auto elemTy = splatAttr.getElementType();
878+
if (!elemTy.isIntOrIndex())
879+
return false;
880+
881+
// Get the splat value as APInt and check if all bits are set (i.e., -1)
882+
APInt val = splatAttr.getSplatValue<APInt>();
883+
return val.isAllOnes();
884+
}
885+
863886
class MatMulConverter final : public OpConversionPattern<tosa::MatMulOp> {
864887
public:
865888
using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -2055,6 +2078,26 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
20552078
assert(succeeded(maybeCurrentSeqLen) && "Must have non-reshape op");
20562079
Value currentSeqLen = maybeCurrentSeqLen.value();
20572080

2081+
// Handle pattern: greater(iota, seqLen + (-1)) which is equivalent to
2082+
// iota >= seqLen. This is a common pattern in KV-cache masking where we
2083+
// want to mask positions >= currentSeqLen.
2084+
if (auto addOp = currentSeqLen.getDefiningOp<tosa::AddOp>()) {
2085+
Value input1 = addOp.getInput1();
2086+
Value input2 = addOp.getInput2();
2087+
// Check if one operand is constant -1
2088+
if (isConstantMinusOne(input2)) {
2089+
// Use input1 as the currentSeqLen candidate
2090+
auto maybeUnwrapped = getValueSkipping(input1, expandAndCollapse);
2091+
if (succeeded(maybeUnwrapped))
2092+
currentSeqLen = maybeUnwrapped.value();
2093+
} else if (isConstantMinusOne(input1)) {
2094+
// Use input2 as the currentSeqLen candidate
2095+
auto maybeUnwrapped = getValueSkipping(input2, expandAndCollapse);
2096+
if (succeeded(maybeUnwrapped))
2097+
currentSeqLen = maybeUnwrapped.value();
2098+
}
2099+
}
2100+
20582101
// Verify currentSeqLen is i32 and traces back to a block argument
20592102
if (!isI32BlockArgument(currentSeqLen, seqLenSkip))
20602103
return failure();
@@ -3385,15 +3428,30 @@ static FailureOr<Value> matchDerefInputPattern(Value derefInput) {
33853428
Value lhsSource = getPreBroadcastSource(lhs);
33863429
Value rhsSource = getPreBroadcastSource(rhs);
33873430

3388-
auto lhsType = cast<ShapedType>(lhsSource.getType());
3389-
auto rhsType = cast<ShapedType>(rhsSource.getType());
3431+
// Helper to trace back through view ops to find the original 3D tensor
3432+
auto traceBackThroughViewOps = [](Value v) -> Value {
3433+
while (Operation *defOp = v.getDefiningOp()) {
3434+
if (!viewOps.contains(defOp->getName().getStringRef()))
3435+
break;
3436+
// All view ops in our set have a single input
3437+
v = defOp->getOperand(0);
3438+
}
3439+
return v;
3440+
};
3441+
3442+
// Trace back through view ops to find the original 3D pointer tensor
3443+
Value lhsOriginal = traceBackThroughViewOps(lhsSource);
3444+
Value rhsOriginal = traceBackThroughViewOps(rhsSource);
3445+
3446+
auto lhsOriginalType = cast<ShapedType>(lhsOriginal.getType());
3447+
auto rhsOriginalType = cast<ShapedType>(rhsOriginal.getType());
33903448

33913449
// Check which one has last dimension = 1 (pointers)
33923450
// The pointers tensor should have shape [batch, blocks, 1]
3393-
if (lhsType.getRank() == 3 && lhsType.getShape()[2] == 1)
3394-
return lhsSource;
3395-
if (rhsType.getRank() == 3 && rhsType.getShape()[2] == 1)
3396-
return rhsSource;
3451+
if (lhsOriginalType.getRank() == 3 && lhsOriginalType.getShape()[2] == 1)
3452+
return lhsOriginal;
3453+
if (rhsOriginalType.getRank() == 3 && rhsOriginalType.getShape()[2] == 1)
3454+
return rhsOriginal;
33973455

33983456
return failure();
33993457
}

0 commit comments

Comments
 (0)