Commit eb2b8bf
[AIROCMLIR-658] Lower migraphx.quant_dot from Linalg to Rock (#2317)
Supporting scaled GEMM in the migraphx-to-linalg and linalg-to-rock pipelines. There are essentially two cases handled in this PR: `quant_dot` with no scale parameters, and `quant_dot` with scale parameters. In the first case, we simply emit a `linalg.batch_matmul`; this follows the same path as `migraphx.dot`. In the second case, we emit a `linalg.generic` loop that takes matrix A, scale A, matrix B, and scale B as inputs (see below); the inner reduction loop of the `linalg.generic` accumulates the product of the four inputs. We also emit the `quant_dot` attribute so that `linalg-to-rock` knows this is a scaled GEMM.

```python
def my_scaled_gemm_impl(a, b, a_scale, b_scale):
    batch = a.shape[0]
    m_dim = a.shape[1]
    k_dim = a.shape[2]
    n_dim = b_scale.shape[2]
    assert k_dim == b.shape[1]
    result = np.zeros((batch, m_dim, n_dim))
    for bb in range(batch):
        for m in range(m_dim):
            for n in range(n_dim):
                for k in range(k_dim):
                    result[bb, m, n] += (a[bb, m, k] * a_scale[bb, m, k] *
                                         b[bb, k, n] * b_scale[bb, k, n])
    return result
```

or, in MLIR form:

```mlir
linalg.generic {indexing_maps = [#map, #map, #map1, #map1, #map2],
                iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%expanded_0, %collapsed, %expanded, %collapsed_4 : ...)
    outs(%2 : tensor<1x64x64xf32>) attrs = {quant_dot = true}
```
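As a quick check of the reference semantics, unit scales should collapse the scaled GEMM back to a plain batched matmul, mirroring the first (unscaled) case above. A minimal numpy sketch (the shapes and tolerance here are illustrative assumptions, not from the PR):

```python
import numpy as np

# With all-ones scales, my_scaled_gemm_impl (defined above) degenerates to
# np.matmul, which is exactly the quant_dot-without-scales case that lowers
# to linalg.batch_matmul.
batch, m, k, n = 1, 4, 8, 4
a = np.random.rand(batch, m, k).astype(np.float32)
b = np.random.rand(batch, k, n).astype(np.float32)

scaled = my_scaled_gemm_impl(a, b, np.ones_like(a), np.ones_like(b))
assert np.allclose(scaled, np.matmul(a, b), atol=1e-5)
```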
1 parent b7bde90 commit eb2b8bf

16 files changed

Lines changed: 482 additions & 52 deletions

mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp

Lines changed: 107 additions & 31 deletions
```diff
@@ -18,15 +18,27 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/PatternMatch.h"
 
+#include <tuple>
+
 using namespace mlir;
 
 namespace {
 template <typename LinalgMatOp>
 struct MatmulConverter final : public OpConversionPattern<LinalgMatOp> {
+  struct MatmulContext {
+    Value aMatrix, bMatrix, scaleA, scaleB;
+    UnitAttr aTransposedAttr, bTransposedAttr, aScaleTransposedAttr,
+        bScaleTransposedAttr;
+  };
+
   using OpConversionPattern<LinalgMatOp>::OpConversionPattern;
   using OpConversionPattern<LinalgMatOp>::getTypeConverter;
   using OpAdaptor = typename OpConversionPattern<LinalgMatOp>::OpAdaptor;
 
+  FailureOr<MatmulContext>
+  getRockMatmulContext(LinalgMatOp op, OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const;
+
   LogicalResult
   matchAndRewrite(LinalgMatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
@@ -37,12 +49,9 @@ struct MatmulConverter final : public OpConversionPattern<LinalgMatOp> {
 /// operandIndex is 0 for A matrix and 1 for B matrix
 /// Returns false if identity map, true if last two dims swapped, failure
 /// otherwise.
-template <typename LinalgOp>
-static FailureOr<bool> isMatrixTransposed(LinalgOp op, unsigned operandIndex) {
-  auto indexingMap =
-      dyn_cast<AffineMapAttr>(op.getIndexingMaps()[operandIndex]);
-  if (!indexingMap || (operandIndex != 1 && operandIndex != 0) ||
-      indexingMap.getAffineMap().getNumResults() < 2) {
+static FailureOr<bool> isMatrixTransposed(AffineMapAttr indexingMap,
+                                          bool isAMatrix) {
+  if (!indexingMap || indexingMap.getAffineMap().getNumResults() < 2) {
     // it is possible for the result of the affine map to have one dimension in
     // the case of broadcasting
     return failure();
@@ -79,17 +88,68 @@ static FailureOr<bool> isMatrixTransposed(LinalgOp op, unsigned operandIndex) {
   // B matrix (operandIndex=1):
   // - Transposed: (d0, d1, d2, d3) -> (d0, d2, d3) i.e., (batch, n, k)
   //   Last two results map to positions: d2->2, d3->3 (swapped)
-  unsigned transposedSecond = operandIndex == 0 ? numInputs - 1 : numInputs - 2;
-  unsigned transposedLast = operandIndex == 0 ? numInputs - 3 : numInputs - 1;
+  unsigned transposedSecond = isAMatrix ? numInputs - 1 : numInputs - 2;
+  unsigned transposedLast = isAMatrix ? numInputs - 3 : numInputs - 1;
   bool isTransposed = (secondLast.getPosition() == transposedSecond &&
                        last.getPosition() == transposedLast);
   return isTransposed;
 }
 
 template <typename LinalgMatOp>
-LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
+FailureOr<typename MatmulConverter<LinalgMatOp>::MatmulContext>
+MatmulConverter<LinalgMatOp>::getRockMatmulContext(
     LinalgMatOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
+  // Nice wrapper around isMatrixTransposed to reduce code duplication
+  auto getTransposeAttrs = [&](AffineMapAttr matrixAIndexingMap,
+                               AffineMapAttr matrixBIndexingMap)
+      -> FailureOr<std::tuple<UnitAttr, UnitAttr>> {
+    FailureOr<bool> maybeATransposed =
+        isMatrixTransposed(matrixAIndexingMap, /*isAMatrix=*/true);
+    FailureOr<bool> maybeBTransposed =
+        isMatrixTransposed(matrixBIndexingMap, /*isAMatrix=*/false);
+    if (failed(maybeATransposed) || failed(maybeBTransposed))
+      return failure();
+    UnitAttr aTransposedAttr =
+        *maybeATransposed ? rewriter.getAttr<UnitAttr>() : nullptr;
+    UnitAttr bTransposedAttr =
+        *maybeBTransposed ? rewriter.getAttr<UnitAttr>() : nullptr;
+    return std::make_tuple(aTransposedAttr, bTransposedAttr);
+  };
+
+  MatmulContext context;
+  if (isa<linalg::GenericOp>(op) && op->hasAttr("rock.quant_dot") &&
+      op.getInputs().size() == 4 && op.getOutputs().size() == 1) {
+    // The linalg.generic op from migraphx-to-linalg place this operand in this
+    // way.
+    context.aMatrix = op.getInputs()[0];
+    context.scaleA = op.getInputs()[1];
+    context.bMatrix = op.getInputs()[2];
+    context.scaleB = op.getInputs()[3];
+
+    auto maybeTranspose =
+        getTransposeAttrs(dyn_cast<AffineMapAttr>(op.getIndexingMaps()[0]),
+                          dyn_cast<AffineMapAttr>(op.getIndexingMaps()[2]));
+    auto maybeScaleTranspose =
+        getTransposeAttrs(dyn_cast<AffineMapAttr>(op.getIndexingMaps()[1]),
+                          dyn_cast<AffineMapAttr>(op.getIndexingMaps()[3]));
+    if (failed(maybeTranspose) || failed(maybeScaleTranspose))
+      return op.emitError("cannot determine if input matrix is transposed");
+    auto [aTransposedAttr, bTransposedAttr] = *maybeTranspose;
+    auto [aScaleTransposedAttr, bScaleTransposedAttr] = *maybeScaleTranspose;
+
+    context.aTransposedAttr = aTransposedAttr;
+    context.aScaleTransposedAttr = aScaleTransposedAttr;
+    context.bTransposedAttr = bTransposedAttr;
+    context.bScaleTransposedAttr = bScaleTransposedAttr;
+    return success(context);
+  }
+
+  // only expect either linalg.matmul or linalg.batch_matmul
+  if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op)) {
+    return failure();
+  }
+
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
@@ -101,34 +161,50 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
         "expected the output to have RankedTensorType and static shape");
   }
 
-  RankedTensorType outputType = cast<RankedTensorType>(cOriginal.getType());
-  Value c = bufferization::AllocTensorOp::create(rewriter, op.getLoc(),
-                                                 outputType, {});
-
-  // Setting the A and B matrix transpose attribute
-  FailureOr<bool> maybeAMatrixTransposed =
-      isMatrixTransposed<LinalgMatOp>(op, 0);
-  FailureOr<bool> maybeBMatrixTransposed =
-      isMatrixTransposed<LinalgMatOp>(op, 1);
-  if (failed(maybeAMatrixTransposed) || failed(maybeBMatrixTransposed)) {
+  auto maybeTranspose =
+      getTransposeAttrs(dyn_cast<AffineMapAttr>(op.getIndexingMaps()[0]),
+                        dyn_cast<AffineMapAttr>(op.getIndexingMaps()[1]));
+  if (failed(maybeTranspose))
     return op.emitError("cannot determine if input matrix is transposed");
+  auto [aTransposedAttr, bTransposedAttr] = *maybeTranspose;
+
+  context.aMatrix = a;
+  context.scaleA = nullptr;
+  context.bMatrix = b;
+  context.scaleB = nullptr;
+  context.aTransposedAttr = aTransposedAttr;
+  context.bTransposedAttr = bTransposedAttr;
+  return success(context);
+}
+
+template <typename LinalgMatOp>
+LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
+    LinalgMatOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  FailureOr<MatmulContext> maybeContext =
+      getRockMatmulContext(op, adaptor, rewriter);
+  if (failed(maybeContext)) {
+    return failure();
   }
-  UnitAttr aTransposedAttr =
-      (maybeAMatrixTransposed.value()) ? rewriter.getAttr<UnitAttr>() : nullptr;
-  UnitAttr bTransposedAttr =
-      (maybeBMatrixTransposed.value()) ? rewriter.getAttr<UnitAttr>() : nullptr;
+  MatmulContext context = maybeContext.value();
 
   // TODO: handle split K attributes as well
   // TODO: handle broadcasting for matrix A and B
-  // TODO: Scaled GEMM not yet supported (scaleA/scaleB currently null)
+  RankedTensorType outputType =
+      cast<RankedTensorType>(op.getOutputs()[0].getType());
   rock::StoreMethodAttr method =
      rewriter.getAttr<rock::StoreMethodAttr>(rock::StoreMethod::Set);
+  Value c = bufferization::AllocTensorOp::create(rewriter, op.getLoc(),
+                                                 outputType, {});
   rock::GemmOp result = rock::GemmOp::create(
-      rewriter, loc, c.getType(), a, b, c, /*scaleA=*/nullptr,
-      /*scaleB=*/nullptr, /*aTransposed=*/aTransposedAttr,
-      /*bTransposed=*/bTransposedAttr,
-      /*cTransposed=*/nullptr, /*aScaleTransposed=*/nullptr,
-      /*bScaleTransposed=*/nullptr, /*features=*/nullptr,
+      rewriter, loc, c.getType(), context.aMatrix, context.bMatrix, c,
+      /*scaleA=*/context.scaleA,
+      /*scaleB=*/context.scaleB, /*aTransposed=*/context.aTransposedAttr,
+      /*bTransposed=*/context.bTransposedAttr,
+      /*cTransposed=*/nullptr,
+      /*aScaleTransposed=*/context.aScaleTransposedAttr,
+      /*bScaleTransposed=*/context.bScaleTransposedAttr, /*features=*/nullptr,
      /*storeMethod=*/method, /*derivedBlockSize=*/nullptr,
       /*gridSize=*/nullptr, /*params=*/nullptr);
 
@@ -186,6 +262,6 @@ LogicalResult ExpandStrideConverter::matchAndRewrite(
 void mlir::rock::populateLinalgToRockConversionPattern(
     RewritePatternSet &pattern, MLIRContext *context) {
   pattern.add<MatmulConverter<linalg::BatchMatmulOp>,
-              MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter>(
-      context);
+              MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter,
+              MatmulConverter<linalg::GenericOp>>(context);
 }
```
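To make the transpose detection above concrete: with the iteration space ordered (d0, d1, d2, d3) = (batch, m, n, k), A counts as transposed when the last two results of its indexing map land on (k, m), and B when they land on (n, k). Here is an illustrative Python model of that check (the function and argument names are mine, not the C++ API):

```python
# Model of isMatrixTransposed: an indexing map is summarized by the loop-dim
# positions of its last two results; num_inputs is the number of loop dims.
def is_matrix_transposed(second_last, last, is_a_matrix, num_inputs=4):
    transposed_second = num_inputs - 1 if is_a_matrix else num_inputs - 2
    transposed_last = num_inputs - 3 if is_a_matrix else num_inputs - 1
    return second_last == transposed_second and last == transposed_last

# A: (batch, m, k) = (d0, d1, d3) vs transposed (batch, k, m) = (d0, d3, d1)
assert not is_matrix_transposed(1, 3, is_a_matrix=True)
assert is_matrix_transposed(3, 1, is_a_matrix=True)
# B: (batch, k, n) = (d0, d3, d2) vs transposed (batch, n, k) = (d0, d2, d3)
assert not is_matrix_transposed(3, 2, is_a_matrix=False)
assert is_matrix_transposed(2, 3, is_a_matrix=False)
```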

mlir/lib/Conversion/LinalgToRock/LinalgToRockPass.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -54,6 +54,11 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) {
     if (!linalgOp) {
      return std::nullopt;
     }
+
+    if (op->hasAttr("rock.quant_dot")) {
+      return false;
+    }
+
     return linalg::isElementwise(linalgOp) || isa<linalg::GenericOp>(op) ||
            isa<linalg::YieldOp>(op);
   });
```
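This legality hook is what forces the conversion: `linalg.generic` ops tagged with `rock.quant_dot` are reported illegal, so the conversion driver must apply `MatmulConverter<linalg::GenericOp>` to them instead of leaving them as ordinary legal linalg. A toy Python model of the predicate (attribute handling simplified to a dict lookup; names are mine):

```python
# Toy model of the dynamic legality callback: ops carrying rock.quant_dot are
# illegal (must be lowered to rock.gemm); otherwise elementwise, generic, and
# yield ops remain legal linalg.
def is_legal_linalg(attrs, is_elementwise, is_generic_or_yield):
    if "rock.quant_dot" in attrs:
        return False
    return is_elementwise or is_generic_or_yield

assert is_legal_linalg({}, False, True)                            # stays linalg
assert not is_legal_linalg({"rock.quant_dot": True}, False, True)  # gets lowered
```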

mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp

Lines changed: 113 additions & 15 deletions
```diff
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Rock/IR/Rock.h"
 #include "mlir/Dialect/Rock/IR/RockTypes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 
 using namespace mlir;
 
@@ -606,25 +607,89 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
   return emitConv(rewriter, op, input, filter);
 }
 
-// TODO: add support for scaled gemms, and migraphx::DeQuantizeLinearConverter
+// TODO: migraphx::DeQuantizeLinearConverter
 //===----------------------------------------------------------------------===//
 // Base kernels (gemm)
 //===----------------------------------------------------------------------===//
 namespace {
-struct DotConverter final : public OpConversionPattern<migraphx::DotOp> {
-  using OpConversionPattern<migraphx::DotOp>::OpConversionPattern;
-  using OpConversionPattern<migraphx::DotOp>::getTypeConverter;
-  using OpAdaptor = typename OpConversionPattern<migraphx::DotOp>::OpAdaptor;
+template <typename MIGXDotOp>
+struct DotConverter final : public OpConversionPattern<MIGXDotOp> {
+  using OpConversionPattern<MIGXDotOp>::OpConversionPattern;
+  using OpConversionPattern<MIGXDotOp>::getTypeConverter;
+  using OpAdaptor = typename OpConversionPattern<MIGXDotOp>::OpAdaptor;
+
+  static_assert(std::is_same_v<MIGXDotOp, migraphx::DotOp> ||
+                    std::is_same_v<MIGXDotOp, migraphx::QuantDotOp>,
+                "MIGXDotOp must be migraphx::DotOp or migraphx::QuantDotOp");
 
   LogicalResult
-  matchAndRewrite(migraphx::DotOp op, OpAdaptor adaptor,
+  matchAndRewrite(MIGXDotOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
+
+  Value createScaledDotGeneric(OpBuilder &rewriter, Location loc, Value aMatrix,
+                               Value scaleA, Value bMatrix, Value scaleB,
+                               RankedTensorType resultType) const;
 };
 } // namespace
 
-LogicalResult
-DotConverter::matchAndRewrite(migraphx::DotOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
+template <typename MIGXDotOp>
+Value DotConverter<MIGXDotOp>::createScaledDotGeneric(
+    OpBuilder &rewriter, Location loc, Value aMatrix, Value scaleA,
+    Value bMatrix, Value scaleB, RankedTensorType resultType) const {
+  auto bodyBuilder = [](OpBuilder &b, Location loc, ValueRange blockArgs) {
+    assert(blockArgs.size() == 5 && "expected 5 arguments");
+
+    SmallVector<Value> inputs =
+        llvm::map_to_vector(blockArgs.drop_back(1), [&](Value arg) {
+          if (!arg.getType().isF32()) {
+            return convertScalarToDtype(b, loc, arg, b.getF32Type(),
+                                        /*isUnsignedCast=*/false);
+          }
+          return arg;
+        });
+
+    Value result = arith::createProduct(b, loc, inputs);
+    if (result.getType() != blockArgs[4].getType()) {
+      result = convertScalarToDtype(b, loc, result, blockArgs[4].getType(),
+                                    /*isUnsignedCast=*/false);
+    }
+    // Accumulate the result
+    ArithBuilder arithBuilder(b, loc);
+    result = arithBuilder.add(result, blockArgs[4]);
+    linalg::YieldOp::create(b, loc, result);
+  };
+
+  Value zero = arith::ConstantOp::create(rewriter, loc, resultType,
+                                         rewriter.getZeroAttr(resultType));
+
+  // The input matrix A has dimensions [batch, m, k], and the input matrix B
+  // has dimensions [batch, k, n]. The output matrix C has dimensions [batch,
+  // m, n].
+  AffineExpr batch = getAffineDimExpr(/*position=*/0, rewriter.getContext()),
+             m = getAffineDimExpr(/*position=*/1, rewriter.getContext()),
+             n = getAffineDimExpr(/*position=*/2, rewriter.getContext()),
+             k = getAffineDimExpr(/*position=*/3, rewriter.getContext());
+  AffineMap aMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                  {batch, m, k}, rewriter.getContext());
+  AffineMap bMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                  {batch, k, n}, rewriter.getContext());
+  AffineMap cMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                  {batch, m, n}, rewriter.getContext());
+  SmallVector<utils::IteratorType> iteratorTypes(3,
+                                                 utils::IteratorType::parallel);
+  iteratorTypes.push_back(utils::IteratorType::reduction);
+
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, resultType, {aMatrix, scaleA, bMatrix, scaleB}, {zero},
+      {aMap, aMap, bMap, bMap, cMap}, iteratorTypes, bodyBuilder);
+  genericOp->setAttr("rock.quant_dot", rewriter.getBoolAttr(true));
+  return genericOp->getResult(0);
+}
+
+template <typename MIGXDotOp>
+LogicalResult DotConverter<MIGXDotOp>::matchAndRewrite(
+    MIGXDotOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   Value inA = adaptor.getInA();
   Value inB = adaptor.getInB();
@@ -719,11 +784,43 @@ DotConverter::matchAndRewrite(migraphx::DotOp op, OpAdaptor adaptor,
     inB = reshapeToDimThree(rankB, newBType, inB);
   }
 
-  auto init = arith::ConstantOp::create(rewriter, loc, newOutType,
-                                        rewriter.getZeroAttr(newOutType))
-                  .getResult();
-  Value result = linalg::BatchMatmulOp::create(rewriter, loc, {inA, inB}, init)
-                     .getResult(0);
+  auto emitLinalgBatchMatmul = [&](Value inA, Value inB,
+                                   RankedTensorType newOutType) {
+    auto init = arith::ConstantOp::create(rewriter, loc, newOutType,
+                                          rewriter.getZeroAttr(newOutType))
+                    .getResult();
+    Value result =
+        linalg::BatchMatmulOp::create(rewriter, loc, {inA, inB}, init)
+            .getResult(0);
+    return result;
+  };
+
+  Value result;
+  if constexpr (std::is_same_v<MIGXDotOp, migraphx::QuantDotOp>) {
+    Value scaleA = adaptor.getScaleA();
+    Value scaleB = adaptor.getScaleB();
+    assert(((scaleA && scaleB) || (!scaleA && !scaleB)) &&
+           "Both scaleA and scaleB must be provided or neither.");
+    bool isScaled = scaleA && scaleB;
+    if (needToReshape && isScaled) {
+      // scaleA and scaleB should have the same type as inputA and inputB
+      RankedTensorType scaleAType =
+          RankedTensorType::get(cast<ShapedType>(inA.getType()).getShape(),
+                                getElementTypeOrSelf(scaleA.getType()));
+      RankedTensorType scaleBType =
+          RankedTensorType::get(cast<ShapedType>(inB.getType()).getShape(),
+                                getElementTypeOrSelf(scaleB.getType()));
+      scaleA = reshapeToDimThree(rankA, scaleAType, scaleA);
+      scaleB = reshapeToDimThree(rankB, scaleBType, scaleB);
+    }
+
+    // only emit scaleA and scaleB if they are not null
+    result = (isScaled) ? createScaledDotGeneric(rewriter, loc, inA, scaleA,
+                                                 inB, scaleB, newOutType)
+                        : emitLinalgBatchMatmul(inA, inB, newOutType);
+  } else {
+    result = emitLinalgBatchMatmul(inA, inB, newOutType);
+  }
 
   // Convert optional attributes
   if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
@@ -1463,7 +1560,8 @@ LiteralConverter::matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
 void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
   patterns
-      .add<DotConverter, ElementwiseConverter<migraphx::AddOp, linalg::AddOp>,
+      .add<DotConverter<migraphx::DotOp>, DotConverter<migraphx::QuantDotOp>,
+           ElementwiseConverter<migraphx::AddOp, linalg::AddOp>,
           ElementwiseConverter<migraphx::SubOp, linalg::SubOp>,
           ElementwiseConverter<migraphx::MulOp, linalg::MulOp>,
           ElementwiseConverter<migraphx::DivOp, linalg::DivOp>,
```
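One detail worth calling out from the hunk above: `reshapeToDimThree` collapses higher-rank operands to the rank-3 (batch, rows, cols) form that both `linalg.batch_matmul` and the scaled `linalg.generic` expect, and when scales are present they are collapsed to exactly the shapes of their matrices. A rough numpy analogue of that collapse (an assumed model of the helper, which is defined elsewhere in this file and not shown in the diff):

```python
import numpy as np

# Assumed model of reshapeToDimThree: fold every leading dimension into a
# single batch dimension, keeping the trailing (rows, cols) pair intact.
def reshape_to_dim_three(t):
    if t.ndim == 3:
        return t
    batch = int(np.prod(t.shape[:-2]))
    return t.reshape(batch, t.shape[-2], t.shape[-1])

a = np.random.rand(2, 3, 4, 8).astype(np.float32)  # rank-4 A operand
a_scale = np.ones_like(a)                          # scaleA matches A's shape
assert reshape_to_dim_three(a).shape == (6, 4, 8)
assert reshape_to_dim_three(a_scale).shape == reshape_to_dim_three(a).shape
```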
