Skip to content

Commit 59c07cf

Browse files
authored
[AIROCMLIR-445] Lower migraphx.backwards_data_convolution (#2256)
Being able to support 1D, 2D, 3D backwards convolution for CPU with all the attributes (strides, dilation, padding, group). Backwards convolution is similar to forward convolution, and it is lowered in these steps: 1. Expand the input and the kernel in the G dimension. Input goes from NC* into NGC*, and filter goes from CF* into GCF* (or FC* to GFC* depending on how you look at it - backwards convolution naming scheme is a bit confusing). 2. Emit the linalg.generic that should behave like the Python loop below. 3. Implement padding through `tensor.extract_slice`.

```python
def my_grouped_impl(input_np, filter_np, output_shape, stride=(1, 1), dilation=(1, 1)):
    result = np.zeros(output_shape, dtype=np.float32)
    batch, group, channel, input_height, input_width = input_np.shape
    group, filter_count, channel, filter_height, filter_width = filter_np.shape
    for n in range(batch):  # 0
        for g in range(group):  # 1
            for hi in range(input_height):  # 2
                for wi in range(input_width):  # 3
                    for f in range(filter_count):  # 4
                        # reduction starts here!
                        for c in range(channel):  # 5
                            for hk in range(filter_height):  # 6
                                for wk in range(filter_width):  # 7
                                    height_access = hi * stride[0] + dilation[0] * hk
                                    width_access = wi * stride[1] + dilation[1] * wk
                                    result[n, g, f, height_access, width_access] += input_np[n, g, c, hi, wi] * filter_np[g, f, c, hk, wk]
    return result
```

If we have padding, it should look like the following code structure:

```mlir
linalg.generic ins (...) outs(%first) ... // computing the transposed kernel
%output = tensor.extract_slice %first [0, 0, padLow0, padLow1, ....][N, F, Ho, Wo, ....][1, 1, 1, ....] // apply padding
```
1 parent c7445a2 commit 59c07cf

13 files changed

Lines changed: 321 additions & 18 deletions

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,19 @@ def LinalgConv_2D
6565
: I32EnumAttrCase<"Conv2dNgchwGkchw", 1, "conv2d_ngchw_gkchw">;
6666
def LinalgConv_3D
6767
: I32EnumAttrCase<"Conv3dNgchwdGkchwd", 2, "conv3d_ngchwd_gkchwd">;
68+
def LinalgBwdConv1D
69+
: I32EnumAttrCase<"Conv1dBWDNgchGckh", 3, "convbwd1d_ngch_gckh">;
70+
def LinalgBwdConv2D
71+
: I32EnumAttrCase<"Conv2dBWDNgchwGckhw", 4, "convbwd2d_ngchw_gckhw">;
72+
def LinalgBwdConv3D
73+
: I32EnumAttrCase<"Conv3dBWDNgchwdGckhwd", 5, "convbwd3d_ngchwd_gckhwd">;
6874

6975
def LinalgConvType
7076
: Rock_I32Enum<"LinalgConvType",
7177
"Hints for the linalg.generic convolution ops used by "
7278
"linalg-to-rock lowering",
73-
[LinalgConv_1D, LinalgConv_2D, LinalgConv_3D]>;
79+
[LinalgConv_1D, LinalgConv_2D, LinalgConv_3D,
80+
LinalgBwdConv1D, LinalgBwdConv2D, LinalgBwdConv3D]>;
7481

7582
def LinalgConvTypeAttr
7683
: EnumAttr<Rock_Dialect, LinalgConvType, "LinalgConvType">;

mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,9 @@ static int64_t getSpatialDim(rock::LinalgConvType type) {
280280
return 2;
281281
case rock::LinalgConvType::Conv3dNgchwdGkchwd:
282282
return 3;
283+
default:
284+
llvm_unreachable("unknown LinalgConvType");
283285
}
284-
llvm_unreachable("unknown LinalgConvType");
285286
}
286287

287288
/// Set filter_layout, input_layout, and output_layout on a rock.conv op.

mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp

Lines changed: 250 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ LogicalResult AsUnderlyingShapeConverter::matchAndRewrite(
272272
return success();
273273
}
274274

275+
//===----------------------------------------------------------------------===//
276+
// Forward and Backward convolution converter
277+
//===----------------------------------------------------------------------===//
275278
namespace {
276279
struct ConvConverter final
277280
: public OpConversionPattern<migraphx::ConvolutionOp> {
@@ -289,6 +292,24 @@ struct ConvConverter final
289292
migraphx::ConvolutionOp op, Value input,
290293
Value filter) const;
291294
};
295+
296+
struct BackwardConvConverter final
297+
: public OpConversionPattern<migraphx::ConvolutionBwdDataOp> {
298+
using OpConversionPattern<
299+
migraphx::ConvolutionBwdDataOp>::OpConversionPattern;
300+
using OpConversionPattern<migraphx::ConvolutionBwdDataOp>::getTypeConverter;
301+
using OpAdaptor =
302+
typename OpConversionPattern<migraphx::ConvolutionBwdDataOp>::OpAdaptor;
303+
304+
LogicalResult
305+
matchAndRewrite(migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
306+
ConversionPatternRewriter &rewriter) const override;
307+
308+
private:
309+
LogicalResult emitBackwardConv(ConversionPatternRewriter &rewriter,
310+
migraphx::ConvolutionBwdDataOp op, Value input,
311+
Value filter) const;
312+
};
292313
} // namespace
293314

294315
// Nice helper function for the linalg.generic op region
@@ -302,19 +323,20 @@ static void convBodyBuilder(OpBuilder &b, Location loc, ValueRange blockArgs) {
302323
}
303324

304325
/// Emit convolution attributes on the newly created operation.
305-
static void emitConvAttributes(migraphx::ConvolutionOp op, Value convOp,
306-
Attribute strides, Attribute dilation,
307-
Attribute pad, Attribute convOpName) {
326+
static void emitConvAttributes(Value convOp, Attribute strides,
327+
Attribute dilation, Attribute pad,
328+
Attribute perfConfig, Attribute groupAttr,
329+
Attribute convOpName) {
308330
Operation *newOp = convOp.getDefiningOp();
309331
newOp->setAttr("pad", pad);
310-
newOp->setAttr("group", op.getGroupAttr());
332+
newOp->setAttr("group", groupAttr);
311333
newOp->setAttr("stride", strides);
312334
newOp->setAttr("dilation", dilation);
313335

314336
// Convert optional attributes
315-
if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
316-
newOp->setAttr("perf_config", attr);
317-
newOp->setAttr("conv_op", convOpName);
337+
if (perfConfig)
338+
newOp->setAttr("perf_config", perfConfig);
339+
newOp->setAttr(rock::linalgConvOpAttrName, convOpName);
318340
}
319341

320342
/// Emit a grouped convolution of any spatial rank (1D, 2D, or 3D).
@@ -403,6 +425,113 @@ static Value emitGroupedConv(ConversionPatternRewriter &rewriter, Location loc,
403425
.getResult(0);
404426
}
405427

428+
/// Emit a grouped backward (transposed) convolution of any spatial rank.
429+
/// Input shape: (batch, group, channel, spatial...),
430+
/// filter shape: (group, filter, channel, kernel_spatial...)
431+
///
432+
/// The loop structure mirrors the forward convolution, but with the
433+
/// stride/dilation affine expression on the *output* indexing map:
434+
///
435+
/// clang-format off
436+
/// for n in batch:
437+
/// for g in group:
438+
/// for ih_0 in input_spatial_0:
439+
/// for ih_1 in input_spatial_1:
440+
/// // ...
441+
/// for ih_{dim-1} in input_spatial_{dim-1}:
442+
/// for f in filters:
443+
/// reduction starts here
444+
/// for c in channels: // reduction
445+
/// for kh_0 in kernel_spatial_0: // reduction
446+
/// for kh_1 in kernel_spatial_1: // reduction
447+
/// // ...
448+
/// result[n,g,f, ih_i*stride_i + kh_i*dilation_i, ...] +=
449+
/// input[n,g,c,ih_0,...] * filter[g,c,f,kh_0,...]
450+
/// clang-format on
451+
static Value emitGroupedBackwardConv(ConversionPatternRewriter &rewriter,
452+
Location loc, RankedTensorType resultType,
453+
Value input, Value filter, Value zero,
454+
ArrayAttr strides, ArrayAttr dilation) {
455+
MLIRContext *ctx = rewriter.getContext();
456+
int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() - 3;
457+
SmallVector<int64_t, 4> strideVals;
458+
SmallVector<int64_t, 4> dilationVals;
459+
llvm::transform(
460+
strides.getValue(), std::back_inserter(strideVals),
461+
[](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); });
462+
llvm::transform(
463+
dilation.getValue(), std::back_inserter(dilationVals),
464+
[](Attribute attr) { return cast<IntegerAttr>(attr).getInt(); });
465+
466+
// Iteration domain layout (mirrors emitGroupedConv):
467+
// parallel: batch, group, ih_0 .. ih_{dim-1}, filter
468+
// reduction: channel, kh_0 .. kh_{dim-1}
469+
// See the loop structure from above to see where these constants come from
470+
const int64_t ihStart = 2;
471+
const int64_t filterIdx = ihStart + spatialDim;
472+
const int64_t channelIdx = filterIdx + 1;
473+
const int64_t khStart = channelIdx + 1;
474+
const int64_t totalDims = khStart + spatialDim;
475+
const int64_t numParallel = channelIdx;
476+
477+
SmallVector<AffineExpr> d;
478+
for (int64_t i = 0; i < totalDims; ++i)
479+
d.push_back(getAffineDimExpr(i, ctx));
480+
481+
AffineExpr batch = d[0], group = d[1];
482+
AffineExpr outChannel = d[filterIdx];
483+
AffineExpr inChannel = d[channelIdx];
484+
485+
SmallVector<AffineExpr> inputExprs = {batch, group, inChannel};
486+
for (int64_t i = 0; i < spatialDim; ++i)
487+
inputExprs.push_back(d[ihStart + i]);
488+
489+
SmallVector<AffineExpr> filterExprs = {group, inChannel, outChannel};
490+
for (int64_t i = 0; i < spatialDim; ++i)
491+
filterExprs.push_back(d[khStart + i]);
492+
493+
SmallVector<AffineExpr> outputExprs = {batch, group, outChannel};
494+
for (int64_t i = 0; i < spatialDim; ++i) {
495+
AffineExpr ih_i = d[ihStart + i];
496+
AffineExpr kh_i = d[khStart + i];
497+
outputExprs.push_back(ih_i * strideVals[i] + kh_i * dilationVals[i]);
498+
}
499+
500+
SmallVector<AffineMap> indexingMaps = {
501+
AffineMap::get(totalDims, /*symbolCount=*/0, inputExprs, ctx),
502+
AffineMap::get(totalDims, /*symbolCount=*/0, filterExprs, ctx),
503+
AffineMap::get(totalDims, /*symbolCount=*/0, outputExprs, ctx)};
504+
505+
SmallVector<utils::IteratorType> iteratorTypes(numParallel,
506+
utils::IteratorType::parallel);
507+
iteratorTypes.append(totalDims - numParallel, utils::IteratorType::reduction);
508+
509+
auto result = linalg::GenericOp::create(
510+
rewriter, loc, resultType, ValueRange{input, filter}, zero,
511+
indexingMaps, iteratorTypes, convBodyBuilder)
512+
.getResult(0);
513+
return result;
514+
}
515+
516+
/// Given the collapsed NF* result type and the group count, return the
517+
/// expanded NGK* result type for the grouped linalg convolution.
518+
static RankedTensorType expandResultForGroupedConv(RankedTensorType resultType,
519+
int64_t group) {
520+
ArrayRef<int64_t> resultShape = resultType.getShape();
521+
int64_t n = resultType.getDimSize(0);
522+
int64_t newF = resultType.getDimSize(1) / group;
523+
assert(resultType.getDimSize(1) % group == 0 &&
524+
"output channel must be divisible by group");
525+
526+
SmallVector<int64_t, 4> newShape;
527+
newShape.push_back(n);
528+
newShape.push_back(group);
529+
newShape.push_back(newF);
530+
newShape.insert(newShape.end(), std::next(resultShape.begin(), 2),
531+
resultShape.end());
532+
return RankedTensorType::get(newShape, resultType.getElementType());
533+
}
534+
406535
LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
407536
migraphx::ConvolutionOp op, Value input,
408537
Value filter) const {
@@ -444,7 +573,8 @@ LogicalResult ConvConverter::emitConv(ConversionPatternRewriter &rewriter,
444573
Value result = emitGroupedConv(rewriter, loc, newResultType, input, filter,
445574
zero, strides, dilation);
446575

447-
emitConvAttributes(op, result, strides, dilation, op.getPaddingAttr(),
576+
emitConvAttributes(result, strides, dilation, op.getPaddingAttr(),
577+
op->getAttr("perf_config"), op.getGroupAttr(),
448578
resultConvOpName);
449579

450580
// we must reshape the operand to what the type converter expects
@@ -608,6 +738,116 @@ ConvConverter::matchAndRewrite(migraphx::ConvolutionOp op, OpAdaptor adaptor,
608738
}
609739

610740
// TODO: migraphx::DeQuantizeLinearConverter
741+
LogicalResult
742+
BackwardConvConverter::emitBackwardConv(ConversionPatternRewriter &rewriter,
743+
migraphx::ConvolutionBwdDataOp op,
744+
Value input, Value filter) const {
745+
Location loc = op.getLoc();
746+
int64_t group = op.getGroupAttr().getInt();
747+
int64_t spatialDim = cast<RankedTensorType>(input.getType()).getRank() -
748+
3; // exclude batch (N), group (G), channel (C)
749+
if (spatialDim > 3)
750+
return op.emitError("only support 1D to 3D conv_bwd");
751+
752+
// To get the result shape, we must first add the padding
753+
ArrayRef<Attribute> padding = op.getPaddingAttr().getValue();
754+
RankedTensorType originalResult =
755+
cast<RankedTensorType>(getTypeConverter()->convertType(op.getResult()));
756+
SmallVector<int64_t, 4> resultShape(originalResult.getShape());
757+
SmallVector<int64_t, 4> lowPads;
758+
SmallVector<int64_t, 4> highPads;
759+
for (int64_t i = 0; i < spatialDim; ++i) {
760+
int64_t lowPad = cast<IntegerAttr>(padding[i]).getInt();
761+
int64_t highPad = cast<IntegerAttr>(padding[i + spatialDim]).getInt();
762+
// The first two dimension of the result is batch and channel, and we apply
763+
// padding to the spatial dimension
764+
resultShape[2 + i] += lowPad + highPad;
765+
lowPads.push_back(lowPad);
766+
highPads.push_back(highPad);
767+
}
768+
RankedTensorType resultType =
769+
RankedTensorType::get(resultShape, originalResult.getElementType());
770+
auto newResultType = expandResultForGroupedConv(resultType, group);
771+
Value zero = arith::ConstantOp::create(rewriter, loc, newResultType,
772+
rewriter.getZeroAttr(newResultType));
773+
774+
ArrayAttr strides = op.getStride();
775+
ArrayAttr dilation = op.getDilation();
776+
777+
Value result = emitGroupedBackwardConv(rewriter, loc, newResultType, input,
778+
filter, zero, strides, dilation);
779+
rock::LinalgConvType convType =
780+
(spatialDim == 3) ? rock::LinalgConvType::Conv3dBWDNgchwdGckhwd
781+
: (spatialDim == 2) ? rock::LinalgConvType::Conv2dBWDNgchwGckhw
782+
: rock::LinalgConvType::Conv1dBWDNgchGckh;
783+
emitConvAttributes(
784+
result, strides, dilation, op.getPaddingAttr(),
785+
op->getAttr("perf_config"), op.getGroupAttr(),
786+
rock::LinalgConvTypeAttr::get(rewriter.getContext(), convType));
787+
788+
// Collapse result from NGK* back to NK*
789+
SmallVector<ReassociationIndices, 4> reassociation{{0}, {1, 2}};
790+
llvm::for_each(llvm::seq<int64_t>(3, spatialDim + 3),
791+
[&](int64_t index) { reassociation.push_back({index}); });
792+
auto finalResult =
793+
tensor::CollapseShapeOp::create(rewriter, loc, result, reassociation)
794+
.getResult();
795+
796+
bool hasPadding = llvm::any_of(lowPads, [](int64_t p) { return p != 0; }) ||
797+
llvm::any_of(highPads, [](int64_t p) { return p != 0; });
798+
if (hasPadding) {
799+
int64_t rank = originalResult.getRank();
800+
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
801+
SmallVector<OpFoldResult> sizes;
802+
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
803+
for (int64_t i = 0; i < rank; ++i)
804+
sizes.push_back(rewriter.getIndexAttr(originalResult.getDimSize(i)));
805+
for (int64_t i = 0; i < spatialDim; ++i)
806+
offsets[2 + i] = rewriter.getIndexAttr(lowPads[i]);
807+
finalResult =
808+
tensor::ExtractSliceOp::create(rewriter, loc, originalResult,
809+
finalResult, offsets, sizes, strides)
810+
.getResult();
811+
}
812+
813+
rewriter.replaceOp(op, finalResult);
814+
return success();
815+
}
816+
817+
LogicalResult BackwardConvConverter::matchAndRewrite(
818+
migraphx::ConvolutionBwdDataOp op, OpAdaptor adaptor,
819+
ConversionPatternRewriter &rewriter) const {
820+
// Backward convolution lowering is similar to forward convolution and is
821+
// lowered in three steps:
822+
// 1. Expand the channel dimension into (group, channel_per_group),
823+
// introducing
824+
// a group dimension G. Input becomes NGC* (e.g. NGCL, NGCHW, NGCDHW) and
825+
// filter becomes GFC* (e.g. GFCL, GFCHW, GFCDHW), matching the group attr.
826+
// 2. Emit the grouped linalg convolution (1D/2D/3D), then collapse the
827+
// result back to the original NFHW/NFDHW shape for the type converter.
828+
Location loc = op.getLoc();
829+
Value input = adaptor.getInput();
830+
Value filter = adaptor.getFilter();
831+
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
832+
int64_t dim = inputType.getRank() - 2;
833+
int64_t group = op.getGroupAttr().getInt();
834+
835+
if (dim > 3 || dim < 1) {
836+
return op.emitError(Twine(dim) + "D conv is not supported for now");
837+
}
838+
839+
if (inputType.getElementType() != op.getFilter().getType().getElementType() ||
840+
inputType.getElementType() != op.getResult().getType().getElementType()) {
841+
return op.emitError(
842+
"type casting between operands and result is unsupported for now");
843+
}
844+
845+
input = expandGroupDim(rewriter, loc, input, /*isFilter=*/false, group, dim);
846+
filter = expandGroupDim(rewriter, loc, filter, /*isFilter=*/true, group, dim);
847+
848+
return emitBackwardConv(rewriter, op, input, filter);
849+
}
850+
611851
//===----------------------------------------------------------------------===//
612852
// Base kernels (gemm)
613853
//===----------------------------------------------------------------------===//
@@ -1583,8 +1823,8 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
15831823
LiteralConverter, ReshapeConverter,
15841824
BooleanElementwiseConverter<migraphx::Greater>,
15851825
BooleanElementwiseConverter<migraphx::Equal>, ClipConverter,
1586-
TransposeConverter, ConvConverter, SliceConverter>(
1587-
converter, patterns.getContext());
1826+
TransposeConverter, ConvConverter, SliceConverter,
1827+
BackwardConvConverter>(converter, patterns.getContext());
15881828
}
15891829

15901830
void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: rocmlir-opt -split-input-file --migraphx-to-linalg --canonicalize --cse --remove-dead-values %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @mlir_bwd_data_conv(
4+
// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}})
5+
// CHECK-DAG: %[[cst:.*]] = arith.constant
6+
// CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg0]]
7+
// CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg1]]
8+
// CHECK-DAG: %[[conv:.*]] = linalg.generic {{.*}} ins(%[[expanded]], %[[expanded_0]] : tensor{{.*}}) outs(%[[cst]] : tensor{{.*}})
9+
// CHECK-SAME: attrs = {conv_op = #rock<LinalgConvType convbwd2d_ngchw_gckhw>, dilation = [1, 1], group = 1 : i64, pad = [1, 1, 1, 1], stride = [2, 3]}
10+
// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[conv]]
11+
// CHECK-DAG: %[[extracted_slice:.*]] = tensor.extract_slice %[[collapsed]]
12+
// CHECK-DAG: %[[collapsed_1:.*]] = tensor.collapse_shape %[[extracted_slice]]
13+
// CHECK-DAG: return %[[collapsed_1]]
14+
func.func @mlir_bwd_data_conv(
15+
%arg0: !migraphx.shaped<1x3x6x7xf32, 126x42x7x1>,
16+
%arg1: !migraphx.shaped<3x4x3x3xf32, 36x9x3x1>
17+
) -> !migraphx.shaped<1x4x11x19xf32, 836x209x19x1> {
18+
%0 = migraphx.backwards_data_convolution %arg0, %arg1 {
19+
dilation = [1, 1],
20+
group = 1 : i64,
21+
padding = [1, 1, 1, 1],
22+
padding_mode = 0 : i64,
23+
stride = [2, 3]} : <1x3x6x7xf32, 126x42x7x1>, <3x4x3x3xf32, 36x9x3x1> -> <1x4x11x19xf32, 836x209x19x1>
24+
return %0 : !migraphx.shaped<1x4x11x19xf32, 836x209x19x1>
25+
}
26+
27+
// -----
28+
29+
// Output grad: NCDHW = 1x1x1x3x3, Filter: CKDHW = 1x1x1x3x3
30+
// stride=[1,1,1], dilation=[1,1,1], padding=[0,0,0,0,0,0], group=1
31+
// CHECK-LABEL: func.func @mlir_bwd_data_conv(
32+
// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}})
33+
// CHECK-DAG: %[[cst:.*]] = arith.constant
34+
// CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg1]]
35+
// CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg0]]
36+
// CHECK-DAG: %[[conv:.*]] = linalg.generic {{.*}} ins(%[[expanded]], %[[expanded_0]] : tensor{{.*}}) outs(%[[cst]] : tensor{{.*}})
37+
// CHECK-SAME: attrs = {conv_op = #rock<LinalgConvType convbwd3d_ngchwd_gckhwd>, dilation = [1, 1, 1], group = 1 : i64, pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]}
38+
// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[conv]]
39+
// CHECK-DAG: return %[[collapsed]]
40+
func.func @mlir_bwd_data_conv(
41+
%arg0: !migraphx.shaped<1x1x1x3x3xf32, 9x9x9x3x1>,
42+
%arg1: !migraphx.shaped<1x1x1x3x3xf32, 9x9x9x3x1>
43+
) -> !migraphx.shaped<1x1x1x5x5xf32, 25x25x25x5x1> attributes {rock.arch = "##TOKEN_ARCH##", rock.kernel} {
44+
%0 = migraphx.backwards_data_convolution %arg1, %arg0 {
45+
dilation = [1, 1, 1],
46+
group = 1 : i64,
47+
padding = [0, 0, 0, 0, 0, 0],
48+
padding_mode = 0 : i64,
49+
stride = [1, 1, 1]
50+
} : <1x1x1x3x3xf32, 9x9x9x3x1>, <1x1x1x3x3xf32, 9x9x9x3x1> -> <1x1x1x5x5xf32, 25x25x25x5x1>
51+
return %0 : !migraphx.shaped<1x1x1x5x5xf32, 25x25x25x5x1>
52+
}

0 commit comments

Comments
 (0)