
Commit c7445a2

[AIROCMLIR-445] Lower linalg.generic convolution into rock (#2252)
The conversion primarily depends on the attributes on `linalg.generic` to extract the padding, dilation, stride, and op code; those attributes are transferred to `rock.conv`. The conversion works as follows:

1. The `isConv` function reads the stride, dilation, and padding from the op attributes and transforms them into the equivalent `rock.conv` attributes.
2. Padding (`tensor.pad`) is removed, since `rock.conv` handles padding internally.
3. The layout is then transformed into the layout expected by `rock.conv`.
4. The attributes for `rock.conv` are emitted.

For more context, this PR tries to match what we currently do in the migraphx -> tosa -> rock pipeline.
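To make the padding handling in step 1 concrete, here is a minimal standalone sketch (plain C++, independent of the MLIR attribute machinery in the patch; the helper name is hypothetical) of the re-layout performed when building the `rock.conv` padding attribute:

```cpp
// Hypothetical helper, for illustration only: the linalg "pad" attribute
// stores all "low" pads followed by all "high" pads, while rock.conv expects
// them interleaved per spatial dimension.
#include <cstdint>
#include <vector>

std::vector<int64_t> interleavePadding(const std::vector<int64_t> &pad) {
  // pad = [dim0_low, dim1_low, ..., dim0_high, dim1_high, ...]
  std::vector<int64_t> interleaved;
  size_t numSpatial = pad.size() / 2;
  for (size_t i = 0; i < numSpatial; ++i) {
    interleaved.push_back(pad[i]);              // dimI_low
    interleaved.push_back(pad[numSpatial + i]); // dimI_high
  }
  return interleaved; // [dim0_low, dim0_high, dim1_low, dim1_high, ...]
}

// Example: {1, 2, 3, 4} (lows {1, 2}, highs {3, 4}) -> {1, 3, 2, 4}.
```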
1 parent eb2b8bf commit c7445a2

13 files changed

Lines changed: 1169 additions & 6 deletions

mlir/include/mlir/Dialect/Rock/IR/Rock.h

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@ TransformMapAttr getTransformMapAttrChecked(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
     MLIRContext *context, ArrayRef<TransformAttr> ops, AffineMapAttr map,
     DenseI64ArrayAttr upperBounds, DenseI64ArrayAttr lowerBounds);
+
+// Attributes used in the linalg.generic operations so that linalg -> rock is
+// easier to match
+constexpr llvm::StringLiteral linalgConvOpAttrName = "conv_op";
 } // namespace rock
 } // namespace mlir
 #endif // MLIR_ROCKOPS_OPS_H_

mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp

Lines changed: 267 additions & 2 deletions
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Rock/IR/Rock.h"
+#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/PatternMatch.h"
@@ -189,7 +190,7 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite(
   }
   MatmulContext context = maybeContext.value();
 
-  // TODO: handle split K attributes as well
+  // TODO: see (AIROCMLIR-696)
   // TODO: handle broadcasting for matrix A and B
   RankedTensorType outputType =
       cast<RankedTensorType>(op.getOutputs()[0].getType());
@@ -259,9 +260,273 @@ LogicalResult ExpandStrideConverter::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvLinalgConverter: linalg.generic (conv) -> rock.conv
+//===----------------------------------------------------------------------===//
+namespace {
+struct ConvFields {
+  rock::LinalgConvType type;
+  int64_t spatialDim;
+  ArrayAttr padding, stride, dilation;
+  StringAttr perfConfig;
+};
+} // namespace
+
+static int64_t getSpatialDim(rock::LinalgConvType type) {
+  switch (type) {
+  case rock::LinalgConvType::Conv1dNgchGkch:
+    return 1;
+  case rock::LinalgConvType::Conv2dNgchwGkchw:
+    return 2;
+  case rock::LinalgConvType::Conv3dNgchwdGkchwd:
+    return 3;
+  }
+  llvm_unreachable("unknown LinalgConvType");
+}
+
+/// Set filter_layout, input_layout, and output_layout on a rock.conv op.
+/// Layouts match the linalg convention: GKC*, NGC*, NGK*.
+static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop,
+                               int64_t spatialDim) {
+  auto *ctx = builder.getContext();
+  auto setLayout = [&](StringRef attrName, ArrayRef<StringRef> prefix,
+                       StringRef suffix) {
+    SmallVector<Attribute> layout;
+    for (StringRef dim : prefix)
+      layout.push_back(StringAttr::get(ctx, dim));
+    for (int64_t i = 0; i < spatialDim; ++i)
+      layout.push_back(StringAttr::get(ctx, Twine(i) + suffix));
+    cop->setAttr(attrName, builder.getArrayAttr(layout));
+  };
+
+  // The input layout in the operand of linalg.generic is NGC*, and
+  // the filter layout is GKC*. We have to transfer these attributes
+  // because later in the pass, ConvToGemm expects them to be attached.
+  setLayout("filter_layout", {"g", "k", "c"}, "");
+  setLayout("input_layout", {"ni", "gi", "ci"}, "i");
+  setLayout("output_layout", {"no", "go", "ko"}, "o");
+}
+
+/// Remove the tensor.pad + tensor.expand_shape pattern emitted by
+/// migraphx-to-linalg, replacing it with just tensor.expand_shape on the
+/// unpadded source. rock.conv handles padding internally.
+///
+/// Expected IR structure:
+///   %padded = tensor.pad %original ...
+///   %expanded = tensor.expand_shape %padded ...
+/// Replaced with:
+///   %expanded = tensor.expand_shape %original ...
+static FailureOr<Value>
+removePaddingFromInput(ConversionPatternRewriter &rewriter,
+                       linalg::GenericOp op, Value in, ArrayAttr padding) {
+  bool hasPadding = llvm::any_of(padding.getValue(), [](Attribute attr) {
+    return cast<IntegerAttr>(attr).getInt() != 0;
+  });
+  if (!hasPadding)
+    return in;
+
+  auto expanded = in.getDefiningOp<tensor::ExpandShapeOp>();
+  auto padded = (expanded != nullptr)
+                    ? expanded->getOperand(0).getDefiningOp<tensor::PadOp>()
+                    : nullptr;
+  // We require the pad op to have a single use because the code structure
+  // emitted by MIGraphX -> Linalg has a single use. In theory this check is
+  // not needed, but better safe than sorry. The same goes for expanded.
+  if (!padded || !padded->hasOneUse()) {
+    op.emitError("unexpected padding code structure");
+    return failure();
+  }
+
+  if (!expanded || !expanded->hasOneUse()) {
+    return op.emitError("unexpected group expansion shape code structure");
+  }
+
+  SmallVector<int64_t> resultShape(expanded.getResultType().getShape());
+  // The tensor.pad operand has no group dimension: [N, G*C, spatial...].
+  // The expanded result has [N, G, C, spatial_padded...]. Take the first 3
+  // dims (N, G, C) from the expanded shape and append the unpadded spatial
+  // dims directly from the pad source starting at position 2.
+  auto padSourceShape =
+      cast<RankedTensorType>(padded.getOperand(0).getType()).getShape();
+  resultShape.resize(3);
+  resultShape.insert(resultShape.begin() + 3, padSourceShape.begin() + 2,
+                     padSourceShape.end());
+
+  RankedTensorType newResultType = RankedTensorType::get(
+      resultShape, padded.getResultType().getElementType());
+  Value result = tensor::ExpandShapeOp::create(
+      rewriter, expanded.getLoc(), newResultType, padded.getOperand(0),
+      expanded.getReassociationIndices());
+  // erase the operations as well
+  rewriter.eraseOp(expanded);
+  rewriter.eraseOp(padded);
+  return result;
+}
+
+namespace {
+struct ConvLinalgConverter final
+    : public OpConversionPattern<linalg::GenericOp> {
+  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
+  using OpConversionPattern<linalg::GenericOp>::getTypeConverter;
+  using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter,
+                               linalg::GenericOp op) const;
+};
+} // namespace
+
+FailureOr<ConvFields>
+ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter,
+                            linalg::GenericOp op) const {
+  auto name =
+      op->getAttrOfType<rock::LinalgConvTypeAttr>(rock::linalgConvOpAttrName);
+  if (!name)
+    return failure();
+  rock::LinalgConvType convType = name.getValue();
+  int64_t spatialDim = getSpatialDim(convType);
+  // Conv1D is broadcast into Conv2D. For error checking we use effectiveDim
+  // instead, since the expanded form has one extra stride/dilation entry for
+  // the added dimension.
+  int64_t effectiveDim = (spatialDim == 1) ? spatialDim + 1 : spatialDim;
+
+  auto convertToArrayAttr =
+      [&](Attribute arr, ArrayRef<int64_t> dimOneDefaults = {}) -> ArrayAttr {
+    if (!arr || !isa<ArrayAttr>(arr)) {
+      return ArrayAttr{};
+    }
+
+    SmallVector<int64_t, 4> values;
+    llvm::transform(
+        cast<ArrayAttr>(arr).getValue(), std::back_inserter(values),
+        [](Attribute val) { return cast<IntegerAttr>(val).getInt(); });
+    // Conv1D is expanded into Conv2D: append identity defaults for the
+    // extra spatial dimension (stride=1, dilation=1, pad=0).
+    if (spatialDim == 1)
+      values.insert(values.end(), dimOneDefaults.begin(), dimOneDefaults.end());
+    return rewriter.getIndexArrayAttr(values);
+  };
+
+  auto dilation =
+      convertToArrayAttr(op->getAttr("dilation"), /*dimOneDefaults=*/{1});
+  auto stride =
+      convertToArrayAttr(op->getAttr("stride"), /*dimOneDefaults=*/{1});
+  if (!dilation || static_cast<int64_t>(dilation.size()) != effectiveDim) {
+    op.emitError("invalid dilation");
+    return failure();
+  }
+
+  if (!stride || static_cast<int64_t>(stride.size()) != effectiveDim) {
+    op.emitError("invalid stride");
+    return failure();
+  }
+
+  // Input format: [dim0_low, dim1_low, ..., dim0_high, dim1_high, ...]
+  // Rock format: [dim0_low, dim0_high, dim1_low, dim1_high, ...]
+  auto originalPadding = convertToArrayAttr(op->getAttr("pad"));
+  if (!originalPadding) {
+    op.emitError("no padding found");
+    return failure();
+  }
+  int64_t numSpatial = originalPadding.size() / 2;
+  SmallVector<Attribute, 8> interleavedPad;
+  for (int64_t i = 0; i < numSpatial; ++i) {
+    interleavedPad.push_back(originalPadding[i]);
+    interleavedPad.push_back(originalPadding[numSpatial + i]);
+  }
+  // Conv1D is expanded into Conv2D
+  if (spatialDim == 1) {
+    interleavedPad.push_back(rewriter.getIndexAttr(0));
+    interleavedPad.push_back(rewriter.getIndexAttr(0));
+  }
+  auto padding = rewriter.getArrayAttr(interleavedPad);
+  // Note that Conv1D is expanded into Conv2D.
+  if (effectiveDim * 2 != (int64_t)padding.size()) {
+    op.emitError("invalid number of padding");
+    return failure();
+  }
+
+  StringAttr perfConfig = op->getAttrOfType<StringAttr>("perf_config");
+  return ConvFields{convType, spatialDim, padding,
+                    stride,   dilation,   perfConfig};
+}
+
+LogicalResult ConvLinalgConverter::matchAndRewrite(
+    linalg::GenericOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  FailureOr<ConvFields> maybeConv = isConv(rewriter, op);
+  if (failed(maybeConv))
+    return failure();
+
+  ConvFields conv = *maybeConv;
+  Location loc = op.getLoc();
+
+  auto maybeInput =
+      removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding);
+  if (failed(maybeInput))
+    return failure();
+
+  Value input = *maybeInput;
+  Value filter = op.getOperand(1);
+
+  // Conv1D is expanded into Conv2D: unmerge the single spatial dim
+  // into (spatial, W=1) for filter and input.
+  int64_t effectiveSpatialDim = conv.spatialDim;
+  if (conv.spatialDim == 1) {
+    effectiveSpatialDim = 2;
+    auto filterShape = cast<RankedTensorType>(filter.getType()).getShape();
+    rock::BottomUpTMBuilder builder(rewriter, {"g", "k", "c", "0"}, filterShape,
+                                    loc);
+    builder.passThrough({"gf", "kf", "cf"}, {0, 1, 2}, {"g", "k", "c"});
+    builder.unmerge({"0f", "1f"}, {3, 4}, "0", {filterShape[3], 1});
+    filter = rock::TransformOp::create(rewriter, loc, filter, builder.get());
+
+    auto inputShape = cast<RankedTensorType>(input.getType()).getShape();
+    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "c", "0"}, inputShape, loc);
+    b.passThrough({"nu", "gu", "cu"}, {0, 1, 2}, {"n", "g", "c"});
+    b.unmerge({"0u", "1u"}, {3, 4}, "0", {inputShape[3], 1});
+    input = rock::TransformOp::create(rewriter, loc, input, b.get());
+  }
+
+  RankedTensorType linalgResultType =
+      cast<RankedTensorType>(op.getResult(0).getType());
+  SmallVector<int64_t> rockShape(linalgResultType.getShape());
+  if (conv.spatialDim == 1)
+    rockShape.push_back(1);
+  RankedTensorType rockResultType =
+      RankedTensorType::get(rockShape, linalgResultType.getElementType());
+  Value output =
+      bufferization::AllocTensorOp::create(rewriter, loc, rockResultType, {});
+  auto cop = rock::ConvOp::create(rewriter, loc, rockResultType, filter, input,
+                                  output, /*features=*/nullptr,
+                                  /*blockSize=*/nullptr, /*gridSize=*/nullptr,
+                                  conv.padding, conv.stride, conv.dilation,
+                                  /*params=*/nullptr);
+  // TODO: add split-K, see (AIROCMLIR-696)
+  if (conv.perfConfig)
+    cop->setAttr("perf_config", conv.perfConfig);
+  setConvLayoutAttrs(rewriter, cop, effectiveSpatialDim);
+
+  Value result = cop.getResult();
+  if (conv.spatialDim == 1) {
+    auto shape = cast<RankedTensorType>(result.getType()).getShape();
+    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "k", "0", "1"}, shape, loc);
+    b.passThrough({"no", "go", "ko"}, {0, 1, 2}, {"n", "g", "k"});
+    b.merge("0o", 3, {"0", "1"});
+    result = rock::TransformOp::create(rewriter, loc, result, b.get());
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 void mlir::rock::populateLinalgToRockConversionPattern(
     RewritePatternSet &pattern, MLIRContext *context) {
   pattern.add<MatmulConverter<linalg::BatchMatmulOp>,
               MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter,
-              MatmulConverter<linalg::GenericOp>>(context);
+              MatmulConverter<linalg::GenericOp>, ConvLinalgConverter>(context);
 }

mlir/lib/Conversion/LinalgToRock/LinalgToRockPass.cpp

Lines changed: 9 additions & 2 deletions
@@ -59,8 +59,15 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) {
       return false;
     }
 
-    return linalg::isElementwise(linalgOp) || isa<linalg::GenericOp>(op) ||
-           isa<linalg::YieldOp>(op);
+    // A convolution linalg.generic has reduction iterator types, so it is
+    // not a legal operation and must be converted.
+    linalg::GenericOp castedOp = dyn_cast<linalg::GenericOp>(op);
+    if (castedOp && castedOp->hasAttr(rock::linalgConvOpAttrName)) {
+      return false;
+    }
+
+    return linalg::isElementwise(linalgOp) || isa<linalg::YieldOp>(op) ||
+           castedOp;
   });
 }
