1818#include " mlir/IR/AffineExpr.h"
1919#include " mlir/IR/PatternMatch.h"
2020
21+ #include < tuple>
22+
2123using namespace mlir ;
2224
2325namespace {
2426template <typename LinalgMatOp>
2527struct MatmulConverter final : public OpConversionPattern<LinalgMatOp> {
// Bundle of values that getRockMatmulContext collects and matchAndRewrite
// forwards to rock::GemmOp::create.
28+ struct MatmulContext {
// A/B matrix operands and their scale operands. The scales are set to nullptr
// on the linalg.matmul / linalg.batch_matmul path.
29+ Value aMatrix, bMatrix, scaleA, scaleB;
// Unit attributes that are set when the matching operand's indexing map is
// detected as transposed and null otherwise. The scale-transpose attributes
// are only assigned on the rock.quant_dot path.
30+ UnitAttr aTransposedAttr, bTransposedAttr, aScaleTransposedAttr,
31+ bScaleTransposedAttr;
32+ };
33+
2634 using OpConversionPattern<LinalgMatOp>::OpConversionPattern;
2735 using OpConversionPattern<LinalgMatOp>::getTypeConverter;
2836 using OpAdaptor = typename OpConversionPattern<LinalgMatOp>::OpAdaptor;
2937
38+ FailureOr<MatmulContext>
39+ getRockMatmulContext (LinalgMatOp op, OpAdaptor adaptor,
40+ ConversionPatternRewriter &rewriter) const ;
41+
3042 LogicalResult
3143 matchAndRewrite (LinalgMatOp op, OpAdaptor adaptor,
3244 ConversionPatternRewriter &rewriter) const override ;
@@ -37,12 +49,9 @@ struct MatmulConverter final : public OpConversionPattern<LinalgMatOp> {
3749/// isAMatrix is true for the A matrix and false for the B matrix.
3850// / Returns false if identity map, true if last two dims swapped, failure
3951// / otherwise.
40- template <typename LinalgOp>
41- static FailureOr<bool > isMatrixTransposed (LinalgOp op, unsigned operandIndex) {
42- auto indexingMap =
43- dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[operandIndex]);
44- if (!indexingMap || (operandIndex != 1 && operandIndex != 0 ) ||
45- indexingMap.getAffineMap ().getNumResults () < 2 ) {
52+ static FailureOr<bool > isMatrixTransposed (AffineMapAttr indexingMap,
53+ bool isAMatrix) {
54+ if (!indexingMap || indexingMap.getAffineMap ().getNumResults () < 2 ) {
4655 // it is possible for the result of the affine map to have one dimension in
4756 // the case of broadcasting
4857 return failure ();
@@ -79,17 +88,68 @@ static FailureOr<bool> isMatrixTransposed(LinalgOp op, unsigned operandIndex) {
7988 // B matrix (operandIndex=1):
8089 // - Transposed: (d0, d1, d2, d3) -> (d0, d2, d3) i.e., (batch, n, k)
8190 // Last two results map to positions: d2->2, d3->3 (swapped)
82- unsigned transposedSecond = operandIndex == 0 ? numInputs - 1 : numInputs - 2 ;
83- unsigned transposedLast = operandIndex == 0 ? numInputs - 3 : numInputs - 1 ;
91+ unsigned transposedSecond = isAMatrix ? numInputs - 1 : numInputs - 2 ;
92+ unsigned transposedLast = isAMatrix ? numInputs - 3 : numInputs - 1 ;
8493 bool isTransposed = (secondLast.getPosition () == transposedSecond &&
8594 last.getPosition () == transposedLast);
8695 return isTransposed;
8796}
8897
8998template <typename LinalgMatOp>
90- LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
99+ FailureOr<typename MatmulConverter<LinalgMatOp>::MatmulContext>
100+ MatmulConverter<LinalgMatOp>::getRockMatmulContext(
91101 LinalgMatOp op, OpAdaptor adaptor,
92102 ConversionPatternRewriter &rewriter) const {
103+ // Runs isMatrixTransposed on an (A, B) pair of indexing maps and returns a UnitAttr per matrix: set if transposed, null otherwise.
104+ auto getTransposeAttrs = [&](AffineMapAttr matrixAIndexingMap,
105+ AffineMapAttr matrixBIndexingMap)
106+ -> FailureOr<std::tuple<UnitAttr, UnitAttr>> {
107+ FailureOr<bool > maybeATransposed =
108+ isMatrixTransposed (matrixAIndexingMap, /* isAMatrix=*/ true );
109+ FailureOr<bool > maybeBTransposed =
110+ isMatrixTransposed (matrixBIndexingMap, /* isAMatrix=*/ false );
111+ if (failed (maybeATransposed) || failed (maybeBTransposed))
112+ return failure ();
113+ UnitAttr aTransposedAttr =
114+ *maybeATransposed ? rewriter.getAttr <UnitAttr>() : nullptr ;
115+ UnitAttr bTransposedAttr =
116+ *maybeBTransposed ? rewriter.getAttr <UnitAttr>() : nullptr ;
117+ return std::make_tuple (aTransposedAttr, bTransposedAttr);
118+ };
119+
120+ MatmulContext context;
121+ if (isa<linalg::GenericOp>(op) && op->hasAttr (" rock.quant_dot" ) &&
122+ op.getInputs ().size () == 4 && op.getOutputs ().size () == 1 ) {
123+ // The linalg.generic op produced by migraphx-to-linalg orders its inputs as
124+ // (A, scaleA, B, scaleB).
125+ context.aMatrix = op.getInputs ()[0 ];
126+ context.scaleA = op.getInputs ()[1 ];
127+ context.bMatrix = op.getInputs ()[2 ];
128+ context.scaleB = op.getInputs ()[3 ];
129+
130+ auto maybeTranspose =
131+ getTransposeAttrs (dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[0 ]),
132+ dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[2 ]));
133+ auto maybeScaleTranspose =
134+ getTransposeAttrs (dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[1 ]),
135+ dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[3 ]));
136+ if (failed (maybeTranspose) || failed (maybeScaleTranspose))
137+ return op.emitError (" cannot determine if input matrix is transposed" );
138+ auto [aTransposedAttr, bTransposedAttr] = *maybeTranspose;
139+ auto [aScaleTransposedAttr, bScaleTransposedAttr] = *maybeScaleTranspose;
140+
141+ context.aTransposedAttr = aTransposedAttr;
142+ context.aScaleTransposedAttr = aScaleTransposedAttr;
143+ context.bTransposedAttr = bTransposedAttr;
144+ context.bScaleTransposedAttr = bScaleTransposedAttr;
145+ return success (context);
146+ }
147+
148+ // only expect either linalg.matmul or linalg.batch_matmul
149+ if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(op)) {
150+ return failure ();
151+ }
152+
93153 Location loc = op.getLoc ();
94154 Value a = op.getOperand (0 );
95155 Value b = op.getOperand (1 );
@@ -101,34 +161,50 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
101161 " expected the output to have RankedTensorType and static shape" );
102162 }
103163
104- RankedTensorType outputType = cast<RankedTensorType>(cOriginal.getType ());
105- Value c = bufferization::AllocTensorOp::create (rewriter, op.getLoc (),
106- outputType, {});
107-
108- // Setting the A and B matrix transpose attribute
109- FailureOr<bool > maybeAMatrixTransposed =
110- isMatrixTransposed<LinalgMatOp>(op, 0 );
111- FailureOr<bool > maybeBMatrixTransposed =
112- isMatrixTransposed<LinalgMatOp>(op, 1 );
113- if (failed (maybeAMatrixTransposed) || failed (maybeBMatrixTransposed)) {
164+ auto maybeTranspose =
165+ getTransposeAttrs (dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[0 ]),
166+ dyn_cast<AffineMapAttr>(op.getIndexingMaps ()[1 ]));
167+ if (failed (maybeTranspose))
114168 return op.emitError (" cannot determine if input matrix is transposed" );
169+ auto [aTransposedAttr, bTransposedAttr] = *maybeTranspose;
170+
171+ context.aMatrix = a;
172+ context.scaleA = nullptr ;
173+ context.bMatrix = b;
174+ context.scaleB = nullptr ;
175+ context.aTransposedAttr = aTransposedAttr;
176+ context.bTransposedAttr = bTransposedAttr;
177+ return success (context);
178+ }
179+
180+ template <typename LinalgMatOp>
181+ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
182+ LinalgMatOp op, OpAdaptor adaptor,
183+ ConversionPatternRewriter &rewriter) const {
184+ Location loc = op.getLoc ();
185+ FailureOr<MatmulContext> maybeContext =
186+ getRockMatmulContext (op, adaptor, rewriter);
187+ if (failed (maybeContext)) {
188+ return failure ();
115189 }
116- UnitAttr aTransposedAttr =
117- (maybeAMatrixTransposed.value ()) ? rewriter.getAttr <UnitAttr>() : nullptr ;
118- UnitAttr bTransposedAttr =
119- (maybeBMatrixTransposed.value ()) ? rewriter.getAttr <UnitAttr>() : nullptr ;
190+ MatmulContext context = maybeContext.value ();
120191
121192 // TODO: handle split K attributes as well
122193 // TODO: handle broadcasting for matrix A and B
123- // TODO: Scaled GEMM not yet supported (scaleA/scaleB currently null)
194+ RankedTensorType outputType =
195+ cast<RankedTensorType>(op.getOutputs ()[0 ].getType ());
124196 rock::StoreMethodAttr method =
125197 rewriter.getAttr <rock::StoreMethodAttr>(rock::StoreMethod::Set);
198+ Value c = bufferization::AllocTensorOp::create (rewriter, op.getLoc (),
199+ outputType, {});
126200 rock::GemmOp result = rock::GemmOp::create (
127- rewriter, loc, c.getType (), a, b, c, /* scaleA=*/ nullptr ,
128- /* scaleB=*/ nullptr , /* aTransposed=*/ aTransposedAttr,
129- /* bTransposed=*/ bTransposedAttr,
130- /* cTransposed=*/ nullptr , /* aScaleTransposed=*/ nullptr ,
131- /* bScaleTransposed=*/ nullptr , /* features=*/ nullptr ,
201+ rewriter, loc, c.getType (), context.aMatrix , context.bMatrix , c,
202+ /* scaleA=*/ context.scaleA ,
203+ /* scaleB=*/ context.scaleB , /* aTransposed=*/ context.aTransposedAttr ,
204+ /* bTransposed=*/ context.bTransposedAttr ,
205+ /* cTransposed=*/ nullptr ,
206+ /* aScaleTransposed=*/ context.aScaleTransposedAttr ,
207+ /* bScaleTransposed=*/ context.bScaleTransposedAttr , /* features=*/ nullptr ,
132208 /* storeMethod=*/ method, /* derivedBlockSize=*/ nullptr ,
133209 /* gridSize=*/ nullptr , /* params=*/ nullptr );
134210
@@ -186,6 +262,6 @@ LogicalResult ExpandStrideConverter::matchAndRewrite(
// Adds the linalg-to-rock conversion patterns to `pattern`, each constructed
// with `context`: MatmulConverter instantiated for linalg::BatchMatmulOp and
// linalg::MatmulOp, plus ExpandStrideConverter.
// This span is shown as a diff: the '-' lines are the previous registration
// list, and the '+' lines additionally register
// MatmulConverter<linalg::GenericOp>.
186262void mlir::rock::populateLinalgToRockConversionPattern (
187263 RewritePatternSet &pattern, MLIRContext *context) {
188264 pattern.add <MatmulConverter<linalg::BatchMatmulOp>,
189- MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter>(
190- context);
265+ MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter,
266+ MatmulConverter<linalg::GenericOp>>( context);
191267}
0 commit comments