@@ -272,6 +272,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite(
272272 return success ();
273273}
274274
275+ // ===----------------------------------------------------------------------===//
276+ // Forward and Backward convolution converter
277+ // ===----------------------------------------------------------------------===//
275278namespace {
276279struct ConvConverter final
277280 : public OpConversionPattern<migraphx::ConvolutionOp> {
@@ -289,6 +292,24 @@ struct ConvConverter final
289292 migraphx::ConvolutionOp op, Value input,
290293 Value filter) const ;
291294};
295+
/// Conversion pattern lowering migraphx.convolution_backwards (data-gradient
/// / transposed convolution) to a grouped linalg.generic convolution.
struct BackwardConvConverter final
    : public OpConversionPattern<migraphx::ConvolutionBwdDataOp> {
  using OpConversionPattern<
      migraphx::ConvolutionBwdDataOp>::OpConversionPattern;
  using OpConversionPattern<migraphx::ConvolutionBwdDataOp>::getTypeConverter;
  using OpAdaptor =
      typename OpConversionPattern<migraphx::ConvolutionBwdDataOp>::OpAdaptor;

  /// Validates rank and element types, expands the group dimension on both
  /// operands, then delegates the actual lowering to emitBackwardConv.
  LogicalResult
  matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  /// Emits the grouped backward convolution for the already group-expanded
  /// `input`/`filter` and replaces `op` with the collapsed (and, when the op
  /// carries padding, sliced) result.
  LogicalResult emitBackwardConv(ConversionPatternRewriter &rewriter,
                                 migraphx::ConvolutionBwdDataOp op, Value input,
                                 Value filter) const;
};
292313} // namespace
293314
294315// Nice helper function for the linalg.generic op region
@@ -302,19 +323,20 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) {
302323}
303324
/// Emit convolution attributes on the newly created operation.
///
/// `convOp` is the value produced by the lowered linalg op; the attributes
/// (pad / group / stride / dilation, the optional perf_config, and the rock
/// conv-op marker) are attached to its defining operation so downstream
/// passes can recover the original convolution configuration.
static void emitConvAttributes(Value convOp, Attribute strides,
                               Attribute dilation, Attribute pad,
                               Attribute perfConfig, Attribute groupAttr,
                               Attribute convOpName) {
  Operation *newOp = convOp.getDefiningOp();
  newOp->setAttr("pad", pad);
  newOp->setAttr("group", groupAttr);
  newOp->setAttr("stride", strides);
  newOp->setAttr("dilation", dilation);

  // Convert optional attributes: perf_config is only propagated when the
  // caller supplied a non-null attribute.
  if (perfConfig)
    newOp->setAttr("perf_config", perfConfig);
  newOp->setAttr(rock::linalgConvOpAttrName, convOpName);
}
319341
320342// / Emit a grouped convolution of any spatial rank (1D, 2D, or 3D).
@@ -403,6 +425,113 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc,
403425 .getResult (0 );
404426}
405427
428+ // / Emit a grouped backward (transposed) convolution of any spatial rank.
429+ // / Input shape: (batch, group, channel, spatial...),
430+ // / filter shape: (group, filter, channel, kernel_spatial...)
431+ // /
432+ // / The loop structure mirrors the forward convolution, but with the
433+ // / stride/dilation affine expression on the *output* indexing map:
434+ // /
435+ // / clang-format off
436+ // / for n in batch:
437+ // / for g in group:
438+ // / for ih_0 in input_spatial_0:
439+ // / for ih_1 in input_spatial_1:
440+ // / // ...
441+ // / for ih_{dim-1} in input_spatial_{dim-1}:
442+ // / for f in filters:
443+ // / reduction starts here
444+ // / for c in channels: // reduction
445+ // / for kh_0 in kernel_spatial_0: // reduction
446+ // / for kh_1 in kernel_spatial_1: // reduction
447+ // / // ...
448+ // / result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] +=
449+ // / input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...]
450+ // / clang-format on
451+ static Value emitGroupedBackwardConv (ConversionPatternRewriter &rewriter,
452+ Location loc, RankedTensorType resultType,
453+ Value input, Value filter, Value zero,
454+ ArrayAttr strides, ArrayAttr dilation) {
455+ MLIRContext *ctx = rewriter.getContext ();
456+ int64_t spatialDim = cast<RankedTensorType>(input.getType ()).getRank () - 3 ;
457+ SmallVector<int64_t , 4 > strideVals;
458+ SmallVector<int64_t , 4 > dilationVals;
459+ llvm::transform (
460+ strides.getValue (), std::back_inserter (strideVals),
461+ [](Attribute attr) { return cast<IntegerAttr>(attr).getInt (); });
462+ llvm::transform (
463+ dilation.getValue (), std::back_inserter (dilationVals),
464+ [](Attribute attr) { return cast<IntegerAttr>(attr).getInt (); });
465+
466+ // Iteration domain layout (mirrors emitGroupedConv):
467+ // parallel: batch, group, ih_0 .. ih_{dim-1}, filter
468+ // reduction: channel, kh_0 .. kh_{dim-1}
469+ // See the loop structure from above to see where these constants come from
470+ const int64_t ihStart = 2 ;
471+ const int64_t filterIdx = ihStart + spatialDim;
472+ const int64_t channelIdx = filterIdx + 1 ;
473+ const int64_t khStart = channelIdx + 1 ;
474+ const int64_t totalDims = khStart + spatialDim;
475+ const int64_t numParallel = channelIdx;
476+
477+ SmallVector<AffineExpr> d;
478+ for (int64_t i = 0 ; i < totalDims; ++i)
479+ d.push_back (getAffineDimExpr (i, ctx));
480+
481+ AffineExpr batch = d[0 ], group = d[1 ];
482+ AffineExpr outChannel = d[filterIdx];
483+ AffineExpr inChannel = d[channelIdx];
484+
485+ SmallVector<AffineExpr> inputExprs = {batch, group, inChannel};
486+ for (int64_t i = 0 ; i < spatialDim; ++i)
487+ inputExprs.push_back (d[ihStart + i]);
488+
489+ SmallVector<AffineExpr> filterExprs = {group, inChannel, outChannel};
490+ for (int64_t i = 0 ; i < spatialDim; ++i)
491+ filterExprs.push_back (d[khStart + i]);
492+
493+ SmallVector<AffineExpr> outputExprs = {batch, group, outChannel};
494+ for (int64_t i = 0 ; i < spatialDim; ++i) {
495+ AffineExpr ih_i = d[ihStart + i];
496+ AffineExpr kh_i = d[khStart + i];
497+ outputExprs.push_back (ih_i * strideVals[i] + kh_i * dilationVals[i]);
498+ }
499+
500+ SmallVector<AffineMap> indexingMaps = {
501+ AffineMap::get (totalDims, /* symbolCount=*/ 0 , inputExprs, ctx),
502+ AffineMap::get (totalDims, /* symbolCount=*/ 0 , filterExprs, ctx),
503+ AffineMap::get (totalDims, /* symbolCount=*/ 0 , outputExprs, ctx)};
504+
505+ SmallVector<utils::IteratorType> iteratorTypes (numParallel,
506+ utils::IteratorType::parallel);
507+ iteratorTypes.append (totalDims - numParallel, utils::IteratorType::reduction);
508+
509+ auto result = linalg::GenericOp::create (
510+ rewriter, loc, resultType, ValueRange{input, filter}, zero,
511+ indexingMaps, iteratorTypes, convBodyBuilder)
512+ .getResult (0 );
513+ return result;
514+ }
515+
516+ // / Given the collapsed NF* result type and the group count, return the
517+ // / expanded NGK* result type for the grouped linalg convolution.
518+ static RankedTensorType expandResultForGroupedConv (RankedTensorType resultType,
519+ int64_t group) {
520+ ArrayRef<int64_t > resultShape = resultType.getShape ();
521+ int64_t n = resultType.getDimSize (0 );
522+ int64_t newF = resultType.getDimSize (1 ) / group;
523+ assert (resultType.getDimSize (1 ) % group == 0 &&
524+ " output channel must be divisible by group" );
525+
526+ SmallVector<int64_t , 4 > newShape;
527+ newShape.push_back (n);
528+ newShape.push_back (group);
529+ newShape.push_back (newF);
530+ newShape.insert (newShape.end (), std::next (resultShape.begin (), 2 ),
531+ resultShape.end ());
532+ return RankedTensorType::get (newShape, resultType.getElementType ());
533+ }
534+
406535LogicalResult ConvConverter::emitConv (ConversionPatternRewriter &rewriter,
407536 migraphx::ConvolutionOp op, Value input,
408537 Value filter) const {
@@ -444,7 +573,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
444573 Value result = emitGroupedConv (rewriter, loc, newResultType, input, filter,
445574 zero, strides, dilation);
446575
447- emitConvAttributes (op, result, strides, dilation, op.getPaddingAttr (),
576+ emitConvAttributes (result, strides, dilation, op.getPaddingAttr (),
577+ op->getAttr (" perf_config" ), op.getGroupAttr (),
448578 resultConvOpName);
449579
450580 // we must reshape the operand to what the type converter expects
@@ -608,6 +738,116 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
608738}
609739
610740// TODO: migraphx::DeQuantizeLinearConverter
/// Lower a group-expanded backward convolution: build the padded grouped
/// result, emit the linalg.generic, collapse NGK* back to NK*, and slice
/// away the padding so the final value matches the converted result type.
LogicalResult
BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter,
                                        migraphx::ConvolutionBwdDataOp op,
                                        Value input, Value filter) const {
  Location loc = op.getLoc();
  int64_t group = op.getGroupAttr().getInt();
  int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() -
                       3; // exclude batch (N), group (G), channel (C)
  if (spatialDim > 3)
    return op.emitError("only support 1D to 3D conv_bwd");

  // To get the result shape, we must first add the padding
  // Padding attr layout: the first `spatialDim` entries are the low pads,
  // the next `spatialDim` entries are the high pads.
  ArrayRef<Attribute> padding = op.getPaddingAttr().getValue();
  RankedTensorType originalResult =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult()));
  SmallVector<int64_t, 4> resultShape(originalResult.getShape());
  SmallVector<int64_t, 4> lowPads;
  SmallVector<int64_t, 4> highPads;
  for (int64_t i = 0; i < spatialDim; ++i) {
    int64_t lowPad = cast<IntegerAttr>(padding[i]).getInt();
    int64_t highPad = cast<IntegerAttr>(padding[i + spatialDim]).getInt();
    // The first two dimension of the result is batch and channel, and we apply
    // padding to the spatial dimension
    resultShape[2 + i] += lowPad + highPad;
    lowPads.push_back(lowPad);
    highPads.push_back(highPad);
  }
  RankedTensorType resultType =
      RankedTensorType::get(resultShape, originalResult.getElementType());
  auto newResultType = expandResultForGroupedConv(resultType, group);
  // Zero-initialized output tensor the generic op accumulates into.
  Value zero = arith::ConstantOp::create(rewriter, loc, newResultType,
                                         rewriter.getZeroAttr(newResultType));

  ArrayAttr strides = op.getStride();
  ArrayAttr dilation = op.getDilation();

  Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input,
                                         filter, zero, strides, dilation);
  // Tag the generic op with the matching rock backward-conv kind per rank.
  rock::LinalgConvType convType =
      (spatialDim == 3)   ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd
      : (spatialDim == 2) ? rock::LinalgConvType::Conv2dBWDNgchwGckhw
                          : rock::LinalgConvType::Conv1dBWDNgchGckh;
  emitConvAttributes(
      result, strides, dilation, op.getPaddingAttr(),
      op->getAttr("perf_config"), op.getGroupAttr(),
      rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType));

  // Collapse result from NGK* back to NK*
  SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}};
  llvm::for_each(llvm::seq<int64_t>(3, spatialDim + 3),
                 [&](int64_t index) { reassociation.push_back({index}); });
  auto finalResult =
      tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation)
          .getResult();

  // When the op carried padding we computed the padded result above; slice
  // the interior back out so the value matches `originalResult`'s shape.
  bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) ||
                    llvm::any_of(highPads, [](int64_t p) { return p != 0; });
  if (hasPadding) {
    int64_t rank = originalResult.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    for (int64_t i = 0; i < rank; ++i)
      sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i)));
    // Only spatial dims are offset, by their low pad.
    for (int64_t i = 0; i < spatialDim; ++i)
      offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]);
    finalResult =
        tensor::ExtractSliceOp::create(rewriter, loc, originalResult,
                                       finalResult, offsets, sizes, strides)
            .getResult();
  }

  rewriter.replaceOp(op, finalResult);
  return success();
}
816+
LogicalResult BackwardConvConverter::matchAndRewrite(
    migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Backward convolution lowering is similar to forward convolution and is
  // lowered in two steps:
  // 1. Expand the channel dimension into (group, channel_per_group),
  //    introducing a group dimension G. Input becomes NGC* (e.g. NGCL,
  //    NGCHW, NGCDHW) and filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW),
  //    matching the group attr.
  // 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the
  //    result back to the original NFHW/NFDHW shape for the type converter.
  Location loc = op.getLoc();
  Value input = adaptor.getInput();
  Value filter = adaptor.getFilter();
  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
  // Spatial rank before group expansion: rank minus batch and channel dims.
  int64_t dim = inputType.getRank() - 2;
  int64_t group = op.getGroupAttr().getInt();

  if (dim > 3 || dim < 1) {
    return op.emitError(Twine(dim) + "D conv is not supported for now");
  }

  // Mixed element types between input, filter, and result are not handled.
  if (inputType.getElementType() != op.getFilter().getType().getElementType() ||
      inputType.getElementType() != op.getResult().getType().getElementType()) {
    return op.emitError(
        "type casting between operands and result is unsupported for now");
  }

  input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim);
  filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim);

  return emitBackwardConv(rewriter, op, input, filter);
}
850+
611851// ===----------------------------------------------------------------------===//
612852// Base kernels (gemm)
613853// ===----------------------------------------------------------------------===//
@@ -1583,8 +1823,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
15831823 LiteralConverter, ReshapeConverter,
15841824 BooleanElementwiseConverter<migraphx::Greater>,
15851825 BooleanElementwiseConverter<migraphx::Equal>, ClipConverter,
1586- TransposeConverter, ConvConverter, SliceConverter>(
1587- converter, patterns.getContext ());
1826+ TransposeConverter, ConvConverter, SliceConverter,
1827+ BackwardConvConverter>( converter, patterns.getContext ());
15881828}
15891829
15901830void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns (
0 commit comments