From 8bbd978393431e6e21fb634034b310524e540408 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 6 May 2026 02:20:28 +0000 Subject: [PATCH 1/2] Port changes to rocMLIR --- .../RocmlirCustomTosaDecompose.cpp | 169 +++++++++--------- .../Dialect/Rock/Transforms/ConvToGemm.cpp | 20 ++- .../Dialect/Rock/utility/loweringUtils.cpp | 17 +- .../rocmlir-custom-tosa-decompose.mlir | 55 ++++++ ...v_to_gemm_bwd_data_empty_filter_slice.mlir | 45 +++++ ...mixr-bwd-data-conv-empty-filter-slice.mlir | 23 +++ 6 files changed, 235 insertions(+), 94 deletions(-) create mode 100644 mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir create mode 100644 mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-empty-filter-slice.mlir diff --git a/mlir/lib/Conversion/RocmlirCustomTosaDecompose/RocmlirCustomTosaDecompose.cpp b/mlir/lib/Conversion/RocmlirCustomTosaDecompose/RocmlirCustomTosaDecompose.cpp index 5e7e1c6830f9..8258907c00a8 100644 --- a/mlir/lib/Conversion/RocmlirCustomTosaDecompose/RocmlirCustomTosaDecompose.cpp +++ b/mlir/lib/Conversion/RocmlirCustomTosaDecompose/RocmlirCustomTosaDecompose.cpp @@ -427,7 +427,7 @@ class TransposeConvStridedConverter : public OpRewritePattern { ShapedType inputTy = cast(input.getType()); ShapedType weightTy = cast(weight.getType()); ShapedType biasTy = cast(bias.getType()); - ShapedType resultTy = cast(op->getResult(0).getType()); + auto resultTy = cast(op->getResult(0).getType()); Type inputETy = inputTy.getElementType(); Type weightETy = weightTy.getElementType(); @@ -596,11 +596,6 @@ class TransposeConvStridedConverter : public OpRewritePattern { dilationVals = {1, 1}; } - // We want to capture the height and width values after dilation expansion, - // but before padding is added later on. - int64_t origWeightHeight = weightHeight; - int64_t origWeightWidth = weightWidth; - // Pad the weight so that it is modulo of the striding. llvm::SmallVector weightPadding = {0, 0, 0, 0, 0, 0, 0, 0}; weightPadding[3] = @@ -739,57 +734,41 @@ class TransposeConvStridedConverter : public OpRewritePattern { rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, convReshapeDims1Value); - // Effective pad = outPad + (paddedK - stride - 1) - (inPad * stride) - int64_t effPadTop = outPad[0] + (origWeightHeight - stride[0] - 1) - - inPadVals[0] * stride[0]; - int64_t effPadLeft = outPad[2] + (origWeightWidth - stride[1] - 1) - - inPadVals[2] * stride[1]; - - // When we shrink from the orignal size to kPrime by grouping stride phases, - // we discard some positions that existed in the conceptual upsampled view. - // The total span of the original field is kOrig -1, and the span - // represented after factoring is kPrime - 1. The difference is the values - // that have been lost int64_t kHPrime = restridedWeightTy.getDimSize(1); int64_t kWPrime = restridedWeightTy.getDimSize(2); - auto lost = [](int64_t Korig, int64_t kPrime, int64_t S) { - return (Korig - 1) - (kPrime - 1) * S; - }; - int64_t lostH = lost(origWeightHeight, kHPrime, stride[0]); - int64_t lostW = lost(origWeightWidth, kWPrime, stride[1]); - - // If stride factoring compresses a dimension to a single spatial position, - // i.e., kPrime == 1, then we dropped a ring of values around that position. - // The adjustment pattern depends on which dimension has asymmetric padding. - - // Height dimension compressed (kHPrime==1) - if (kHPrime == 1 && lostH > 0) { - int64_t adjustment = lostH / 2; - bool hasAsymmetricWidth = (weightPadding[4] != weightPadding[5]); - if (hasAsymmetricWidth) { - effPadTop -= adjustment; - effPadLeft += adjustment; - } else { - effPadLeft += adjustment; - } - } - // Width dimension compressed (kWPrime==1) - if (kWPrime == 1 && lostW > 0) { - effPadTop += lostW / 2; - } + // After factoring stride phases out of the filter channels, a contribution + // from gradient-output index h and original filter index k lands in the + // expanded stride-1 conv result at: + // ho_expanded = h * stride + k + pad_low * stride - stride * (kPrime - 1) + // where kPrime is the reduced spatial filter size used by the stride-1 + // convolution. The output of this op is dx[i] (with i = h*stride + k - + // pad_low) shifted by out_pad_low. So output position 0 corresponds to + // h*stride + k = pad_low - out_pad_low, which substituted above gives the + // low-side offset into the expanded conv result: + // offset = pad_low * (stride + 1) - stride * (kPrime - 1) - out_pad_low + // Positive offset means we crop the expanded result on the low side; + // negative offset means we pre-pad the result on the low side. + auto computeLowSideOffset = [](int64_t inPadLow, int64_t outPadLow, + int64_t strideVal, int64_t kPrime) { + return inPadLow * (strideVal + 1) - strideVal * (kPrime - 1) - outPadLow; + }; + int64_t offsetTop = + computeLowSideOffset(inPadVals[0], outPad[0], stride[0], kHPrime); + int64_t offsetLeft = + computeLowSideOffset(inPadVals[2], outPad[2], stride[1], kWPrime); int64_t resultSliceTop; int64_t resultSliceLeft; int64_t resultPadTop; int64_t resultPadLeft; - // Convert effective padding into slice (crop) and post-pad just like the - // prior logic but now using effPad*. + // Convert low-side offset into slice (crop) and post-pad. Both branches + // feed into the shared clamping + zero-result-overlap logic below. if (op->hasAttr("pad")) { - resultSliceTop = std::max(0, -effPadTop); - resultSliceLeft = std::max(0, -effPadLeft); - resultPadTop = std::max(0, effPadTop); - resultPadLeft = std::max(0, effPadLeft); + resultSliceTop = std::max(0, offsetTop); + resultSliceLeft = std::max(0, offsetLeft); + resultPadTop = std::max(0, -offsetTop); + resultPadLeft = std::max(0, -offsetLeft); } else { // Default to using legacy logic if input padding is not present resultSliceTop = std::max(0, -outPad[0]); @@ -798,39 +777,67 @@ class TransposeConvStridedConverter : public OpRewritePattern { resultPadLeft = std::max(0, outPad[2]); } - // Try to slice the targetted result size, cap to the convolutions width. - int64_t resultSliceHeight = - std::min(convReshapeDims1[1] - resultSliceTop, - resultTy.getDimSize(1) - resultPadTop); - int64_t resultSliceWidth = - std::min(convReshapeDims1[2] - resultSliceLeft, - resultTy.getDimSize(2) - resultPadLeft); - - llvm::SmallVector sliceBegin = {0, resultSliceTop, - resultSliceLeft, 0}; - llvm::SmallVector sliceSize(convReshapeDims1.begin(), - convReshapeDims1.end()); - sliceSize[1] = resultSliceHeight; - sliceSize[2] = resultSliceWidth; - - auto slice = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, - getTosaConstShape(rewriter, loc, sliceBegin), - getTosaConstShape(rewriter, loc, sliceSize)) - .getResult(); - - llvm::SmallVector resultPadding = {0, 0, 0, 0, 0, 0, 0, 0}; - resultPadding[2] = resultPadTop; - resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1]; - resultPadding[4] = resultPadLeft; - resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2]; - - Value resultPaddingVal = - getTosaConstShape(rewriter, op->getLoc(), resultPadding); - - Value resultPad = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), slice, - resultPaddingVal); + int64_t resultHeight = resultTy.getDimSize(1); + int64_t resultWidth = resultTy.getDimSize(2); + int64_t convExpandedHeight = convReshapeDims1[1]; + int64_t convExpandedWidth = convReshapeDims1[2]; + + // Extreme low-side padding/cropping can leave no overlap with the expanded + // convolution result. Keep the slice window valid and let padding fill the + // requested result extent. + resultPadTop = std::min(resultPadTop, resultHeight); + resultPadLeft = std::min(resultPadLeft, resultWidth); + resultSliceTop = std::min(resultSliceTop, convExpandedHeight); + resultSliceLeft = std::min(resultSliceLeft, convExpandedWidth); + + // Try to slice the targeted result size, cap to the convolution extent. + int64_t resultSliceHeight = std::min( + convExpandedHeight - resultSliceTop, resultHeight - resultPadTop); + int64_t resultSliceWidth = std::min( + convExpandedWidth - resultSliceLeft, resultWidth - resultPadLeft); + + // The clamping above guarantees both arguments to each `min` are + // non-negative, so the slice extents must be too. + assert(resultSliceHeight >= 0 && resultSliceWidth >= 0 && + "slice extents must be non-negative after clamping"); + + Value resultPad; + if (resultSliceHeight == 0 || resultSliceWidth == 0) { + // TOSA requires positive slice sizes. If the output window has no + // overlap with the expanded convolution result, materialize the pre-bias + // result directly as zeros. + resultPad = tosa::ConstOp::create( + rewriter, loc, resultTy, + DenseElementsAttr::get(resultTy, rewriter.getZeroAttr(resultETy))); + } else { + llvm::SmallVector sliceBegin = {0, resultSliceTop, + resultSliceLeft, 0}; + llvm::SmallVector sliceSize(convReshapeDims1.begin(), + convReshapeDims1.end()); + sliceSize[1] = resultSliceHeight; + sliceSize[2] = resultSliceWidth; + + auto slice = CreateOpAndInferShape( + rewriter, loc, UnrankedTensorType::get(resultETy), + conv2d, getTosaConstShape(rewriter, loc, sliceBegin), + getTosaConstShape(rewriter, loc, sliceSize)) + .getResult(); + + llvm::SmallVector resultPadding = {0, 0, 0, 0, 0, 0, 0, 0}; + resultPadding[2] = resultPadTop; + resultPadding[3] = resultHeight - resultPadTop - sliceSize[1]; + resultPadding[4] = resultPadLeft; + resultPadding[5] = resultWidth - resultPadLeft - sliceSize[2]; + assert(resultPadding[3] >= 0 && resultPadding[5] >= 0 && + "post-slice pad extents must be non-negative"); + + Value resultPaddingVal = + getTosaConstShape(rewriter, op->getLoc(), resultPadding); + + resultPad = CreateOpAndInferShape( + rewriter, loc, UnrankedTensorType::get(resultETy), slice, + resultPaddingVal); + } if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) { return failure(); diff --git a/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp b/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp index 2aa8be25ec3a..8978b637740d 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp @@ -53,6 +53,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" +#include "llvm/Support/MathExtras.h" #include #include @@ -830,9 +831,16 @@ FailureOr> backwardDataV4R1(ConvBwdDataOp op, iTilda[1] = (kernelId % product) / divisor; iTilda[0] = kernelId / product; } - for (size_t i = 0; i < convDims.fil.size(); i++) - iDotSlice.push_back(math_util::integer_divide_ceil( - convDims.fil[i] - iTilda[i], filTilda[i])); + // `kernelId` must come from `backwardDataKernelIds`, which filters out + // phases where `iTilda[i] >= convDims.fil[i]`. Without that filter, + // `divideCeil`'s unsigned-converting overload would wrap a negative + // numerator into a huge value here. + for (size_t i = 0; i < convDims.fil.size(); i++) { + assert(iTilda[i] < convDims.fil[i] && + "kernelId not pre-filtered by backwardDataKernelIds"); + iDotSlice.push_back( + llvm::divideCeil(convDims.fil[i] - iTilda[i], filTilda[i])); + } // backward data only, it's igemm v4r1 algo // c is input channels , k is output channels @@ -1125,11 +1133,11 @@ commonConvRewrite(T op, PatternRewriter &b, ConvolutionContext &ctx, auto strideDims = ctx.getStrideVal(); auto dilationDims = ctx.getDilationVal(); auto filterDims = ctx.getConvDims().fil; - auto numKernels = + auto kernelIds = rock::backwardDataKernelIds(strideDims, dilationDims, filterDims, /*usesV4R1=*/true); - for (size_t i = 0; i < numKernels.size(); i++) { - auto maybe = backwardDataV4R1(bwdDataOp, b, i, usesV4R1); + for (int64_t kernelId : kernelIds) { + auto maybe = backwardDataV4R1(bwdDataOp, b, kernelId, usesV4R1); if (failed(maybe)) return failure(); } diff --git a/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp b/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp index dda8d88d32cb..e39b63de71f4 100644 --- a/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp +++ b/mlir/lib/Dialect/Rock/utility/loweringUtils.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" @@ -120,6 +121,7 @@ mlir::rock::backwardDataKernelIds(ArrayRef strideDims, ArrayRef dilationDims, ArrayRef filterDims, bool usesV4R1) { assert(strideDims.size() == dilationDims.size()); + SmallVector gcdStrideDilations; for (const auto &[stride, dilation] : zip(strideDims, dilationDims)) gcdStrideDilations.push_back(math_util::gcd(stride, dilation)); @@ -139,7 +141,6 @@ mlir::rock::backwardDataKernelIds(ArrayRef strideDims, for (int64_t kernelId = 0; kernelId < product; ++kernelId) { // gemmK size is different for each GEMM SmallVector iTilda; - SmallVector iDotSlice; int64_t divisor = 1; iTilda.resize(filterDims.size()); switch (filterDims.size()) { @@ -154,14 +155,16 @@ mlir::rock::backwardDataKernelIds(ArrayRef strideDims, iTilda[1] = (kernelId % subproduct) / divisor; iTilda[0] = kernelId / subproduct; } - for (size_t i = 0; i < filterDims.size(); i++) - iDotSlice.push_back(math_util::integer_divide_ceil( - filterDims[i] - iTilda[i], filTilda[i])); - // gemmK must > 0, otherwise not need to run + // gemmK must be > 0, otherwise this kernel has no filter slice to run. int64_t gemmKproduct = 1; - for (int64_t ds : iDotSlice) - gemmKproduct *= ds; + for (size_t i = 0; i < filterDims.size(); i++) { + if (iTilda[i] >= filterDims[i]) { + gemmKproduct = 0; + break; + } + gemmKproduct *= llvm::divideCeil(filterDims[i] - iTilda[i], filTilda[i]); + } if (gemmKproduct > 0) { kernelIds.push_back(kernelId); } diff --git a/mlir/test/Conversion/RocmlirCustomTosaDecompose/rocmlir-custom-tosa-decompose.mlir b/mlir/test/Conversion/RocmlirCustomTosaDecompose/rocmlir-custom-tosa-decompose.mlir index 13bfdccd6e7c..18abb5d02c8e 100644 --- a/mlir/test/Conversion/RocmlirCustomTosaDecompose/rocmlir-custom-tosa-decompose.mlir +++ b/mlir/test/Conversion/RocmlirCustomTosaDecompose/rocmlir-custom-tosa-decompose.mlir @@ -108,4 +108,59 @@ func.func @bwd_data_conv1d(%arg0: tensor<64xf32>, %arg1: tensor<672xf32>, %arg2: } // ----- +// Regression test for the no-overlap zero-result branch in +// TransposeConvStridedConverter. With weightHeight=4 and stride_h=2 we get +// kHPrime=2, so the low-side height offset is +// inPadLow*(stride+1) - stride*(kPrime-1) - outPadLow = 2*3 - 2*1 - 0 = 4. +// convExpandedHeight = 4, so resultSliceTop saturates at convExpandedHeight +// and resultSliceHeight collapses to 0. The lowering must materialize the +// pre-bias result as a zero constant rather than emit a zero-extent +// tosa.slice (TOSA disallows empty slice extents). +// +// CHECK-LABEL: func @bwd_data_conv2d_empty_slice +// CHECK-NOT: tosa.custom +// CHECK-NOT: tosa.slice +// CHECK: %[[Z:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x2x4x1xf32>}> : () -> tensor<1x2x4x1xf32> +// CHECK: tosa.add %[[Z]], %{{.*}} : (tensor<1x2x4x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x2x4x1xf32> +func.func @bwd_data_conv2d_empty_slice(%grad_out: tensor<1x1x4x1xf32>, %weight: tensor<1x4x1x1xf32>) -> tensor<1x2x4x1xf32> { + %izp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %wzp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %bias = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.custom %grad_out, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array, domain_name = "rocmlir", group = 1 : i64, implementation_attrs = "", operator_name = "conv_bwd_data", out_pad = array, pad = array, stride = array} : (tensor<1x1x4x1xf32>, tensor<1x4x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4x1xf32> + return %0 : tensor<1x2x4x1xf32> +} + +// ----- +// Locks in the new geometric low-side offset formula on the case where +// the legacy formula's `kHPrime == 1 && lostH > 0 && hasAsymmetricWidth` +// fixup was load-bearing. +// +// Shape: weight 1x3x2x1, stride [3, 3], pad [0,0,0,0], dilation [1,1]. +// `weightWidth (2) % stride[1] (3) != 0` so the modulo-stride pad makes +// `weightPadding[4] (0) != weightPadding[5] (1)`, which under the old +// code took the asymmetric-width branch with adjustment = lostH/2 = 1 +// (lostH = (3-1) - (1-1)*3 = 2). That branch produced +// effPadTop = -1 - 1 = -2, effPadLeft = -2 + 1 = -1 +// -> slice begin [0, 2, 1, 0] size [1, 4, 5, 1]. +// The new formula is independent of the asymmetric-width branch: +// offsetTop = inPadLow*(stride+1) - stride*(kHPrime-1) - outPadLow +// = 0*4 - 3*0 - 0 = 0 +// offsetLeft = 0 +// so the slice now begins at the origin of the 6x6 expanded conv +// result and takes the full 6 rows / 5 columns of output. +// +// CHECK-LABEL: func @bwd_data_conv2d_kprime_one_asym_weight_pad +// CHECK-NOT: tosa.custom +// CHECK: %[[SBEGIN:.*]] = tosa.const_shape {{.*}}values = dense<0> : tensor<4xindex>{{.*}} -> !tosa.shape<4> +// CHECK: %[[SSIZE:.*]] = tosa.const_shape {{.*}}values = dense<[1, 6, 5, 1]> : tensor<4xindex>{{.*}} -> !tosa.shape<4> +// CHECK: tosa.slice %{{.*}}, %[[SBEGIN]], %[[SSIZE]] : (tensor<1x6x6x1xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x6x5x1xf32> +func.func @bwd_data_conv2d_kprime_one_asym_weight_pad(%grad_out: tensor<1x2x2x1xf32>, %weight: tensor<1x3x2x1xf32>) -> tensor<1x6x5x1xf32> { + %izp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %wzp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %bias = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.custom %grad_out, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array, domain_name = "rocmlir", group = 1 : i64, implementation_attrs = "", operator_name = "conv_bwd_data", out_pad = array, pad = array, stride = array} : (tensor<1x2x2x1xf32>, tensor<1x3x2x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x5x1xf32> + return %0 : tensor<1x6x5x1xf32> +} + +// ----- diff --git a/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir b/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir new file mode 100644 index 000000000000..46493a3fc74a --- /dev/null +++ b/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir @@ -0,0 +1,45 @@ +// backwardDataKernelIds used to emit phantom kernel IDs for stride phases whose +// filter slice is empty. + +// The pre-fix implementation expressed the slice extent as +// `divideCeil(filterDims[i] - iTilda[i], filTilda[i])`. LLVM's default +// `divideCeil` is the unsigned-converting overload, so a negative numerator wrapped to a huge positive value -- making an empty phase look like real GEMM work and emitting a bogus per-phase rock.gemm. +// The fix detects iTilda[i] >= filterDims[i] and short-circuits gemmKproduct to 0 so the phase is correctly excluded. +// We pin the count of `rock.gemm` ops emitted by `rock-conv-to-gemm`, since that pass calls backwardDataKernelIds once per ConvBwdDataOp and emits one rock.gemm per returned kernel ID. + +// Rank-2 reproducer for the originally failing shape: filTilda = {2, 3}, +// filterDims = {2, 1}. The fix prunes iTilda[1] in {1, 2}, leaving the +// kernel-ID set {0, 3}. Without the fix iTilda[1] == 2 (kernel IDs 2, 5) +// would also slip through with a wrapped slice extent, giving 4 gemms. +// RUN: rocmlir-gen --operation=conv_bwd_data --arch %arch -v4r1 0 --kernel_id 0 \ +// RUN: --fil_layout=gkcyx --in_layout=ngchw --out_layout=ngkhw \ +// RUN: --batchsize=1 --groupsize=1 --in_channels=4 --out_channels=4 \ +// RUN: --in_h=4 --in_w=8 --fil_h=2 --fil_w=1 \ +// RUN: --dilation_h=1 --dilation_w=2 \ +// RUN: --conv_stride_h=2 --conv_stride_w=3 \ +// RUN: --padding_h=0 --padding_w=0 \ +// RUN: | rocmlir-driver -c --mlir-print-ir-after=rock-conv-to-gemm 2>&1 \ +// RUN: | FileCheck %s --check-prefix=RANK2 + +// Rank-3 reproducer exercising the rank-3 switch arm of +// backwardDataKernelIds: +// filTilda = {2, 2, 3}, filterDims = {2, 2, 1}, so only iTilda[2] +// (the dim with filterDims[i] < filTilda[i]) can index out of bounds. +// The fix prunes iTilda[2] in {1, 2}, leaving kernel IDs +// {0, 3, 6, 9} == 4 gemms. Without the fix iTilda[2] == 2 would also +// slip through (kernel IDs 2, 5, 8, 11), giving 8 gemms. +// RUN: rocmlir-gen --operation=conv_bwd_data --arch %arch -v4r1 0 --kernel_id 0 \ +// RUN: --fil_layout=gkc012 --in_layout=ngc012 --out_layout=ngk012 \ +// RUN: --batchsize=1 --groupsize=1 --in_channels=4 --out_channels=4 \ +// RUN: --in_d=4 --in_h=4 --in_w=8 --fil_d=2 --fil_h=2 --fil_w=1 \ +// RUN: --dilation_d=1 --dilation_h=1 --dilation_w=2 \ +// RUN: --conv_stride_d=2 --conv_stride_h=2 --conv_stride_w=3 \ +// RUN: --padding_d=0 --padding_h=0 --padding_w=0 \ +// RUN: | rocmlir-driver -c --mlir-print-ir-after=rock-conv-to-gemm 2>&1 \ +// RUN: | FileCheck %s --check-prefix=RANK3 + +// RANK2-COUNT-2: {{rock\.gemm[[:>:]]}} +// RANK2-NOT: {{rock\.gemm[[:>:]]}} + +// RANK3-COUNT-4: {{rock\.gemm[[:>:]]}} +// RANK3-NOT: {{rock\.gemm[[:>:]]}} diff --git a/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-empty-filter-slice.mlir b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-empty-filter-slice.mlir new file mode 100644 index 000000000000..2f6b6e5aac37 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-bwd-data-conv-empty-filter-slice.mlir @@ -0,0 +1,23 @@ +// RUN: rocmlir-gen -fut mlir_bwd_data_conv --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel -targets %arch | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_bwd_data_conv_wrapper --verifier clone -relDiff_threshold 0.00001 - | rocmlir-driver -host-pipeline mhal,runner -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s + +// We also want to check the original rocmlir-gen command that initially hit +// this issue is fixed. The MIGraphX -> Tosa -> Linalg CPU lowering currently does not +// support large group sizes (> 1), so we need to use the original rocmlir-gen command +// to check that the issue is fixed with g=4. +// RUN: rocmlir-gen --operation conv_bwd_data -t f16 --arch %arch -v4r1 0 --kernel_id 0 --fil_layout kyxc --in_layout nhwc --out_layout nhwk --batchsize 1 --in_channels 64 --in_h 32 --in_w 14 --out_channels 256 --fil_h 2 --fil_w 1 --dilation_h 1 --dilation_w 2 --conv_stride_h 2 --conv_stride_w 3 --padding_h 0 --padding_w 3 --groupsize 4 --perf_config= -pv | rocmlir-driver -c | mlir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=ROCKGEN + +// ROCKGEN: [1 1 1] + +module { + // CHECK: [1 1 1] + func.func @mlir_bwd_data_conv(%grad_out: !migraphx.shaped<1x16x16x7xf16, 1792x112x7x1>, %weights: !migraphx.shaped<16x16x2x1xf16, 32x2x1x1>) -> !migraphx.shaped<1x16x32x19xf16, 9728x608x19x1> attributes {rock.kernel} { + %res = migraphx.backwards_data_convolution %grad_out, %weights { + dilation = [1, 2], + group = 1 : i64, + padding = [0, 0, 0, 0], + padding_mode = 0 : i64, + stride = [2, 3] + } : <1x16x16x7xf16, 1792x112x7x1>, <16x16x2x1xf16, 32x2x1x1> -> <1x16x32x19xf16, 9728x608x19x1> + return %res : !migraphx.shaped<1x16x32x19xf16, 9728x608x19x1> + } +} From 1ba6b3e3f85844704c28d756ce5993f83563cc2e Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 6 May 2026 13:34:45 +0000 Subject: [PATCH 2/2] Attend to AI review comments --- .../Dialect/Rock/Transforms/ConvToGemm.cpp | 26 ++++-- ...v_to_gemm_bwd_data_empty_filter_slice.mlir | 91 +++++++++++++++---- 2 files changed, 92 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp b/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp index 8978b637740d..3f7547db4f00 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp @@ -1123,19 +1123,33 @@ commonConvRewrite(T op, PatternRewriter &b, ConvolutionContext &ctx, if (ConvOpType::BwdData == convOpType) { auto bwdDataOp = cast(op); bool usesV4R1 = op->template getAttrOfType("usesV4R1").getValue(); + auto strideDims = ctx.getStrideVal(); + auto dilationDims = ctx.getDilationVal(); + auto filterDims = ctx.getConvDims().fil; + auto kernelIds = rock::backwardDataKernelIds(strideDims, dilationDims, + filterDims, /*usesV4R1=*/true); if (usesV4R1) { auto kernelId = bwdDataOp.getKernelIdAttr().getInt(); + // `backwardDataV4R1` requires that `iTilda[i] < filterDims[i]` for every + // dimension; otherwise its `llvm::divideCeil` (unsigned) wraps a + // negative numerator into a huge slice extent. Reject ids that do not + // correspond to a real, non-empty stride phase. + if (!llvm::is_contained(kernelIds, kernelId)) { + InFlightDiagnostic diag = + bwdDataOp.emitOpError() + << "v4r1 kernel id " << kernelId + << " has an empty filter slice and cannot be lowered; valid v4r1 " + "kernel ids for this convolution shape are {"; + for (auto [i, id] : llvm::enumerate(kernelIds)) + diag << (i == 0 ? "" : ", ") << id; + diag << "}"; + return failure(); + } return backwardDataV4R1(bwdDataOp, b, kernelId, usesV4R1); } else { // For the cases where the V4R1 algorithm requires more than one kernel, // i.e., stride != dilation, we want to create multiple GEMMs in a // single kernel - auto strideDims = ctx.getStrideVal(); - auto dilationDims = ctx.getDilationVal(); - auto filterDims = ctx.getConvDims().fil; - auto kernelIds = - rock::backwardDataKernelIds(strideDims, dilationDims, filterDims, - /*usesV4R1=*/true); for (int64_t kernelId : kernelIds) { auto maybe = backwardDataV4R1(bwdDataOp, b, kernelId, usesV4R1); if (failed(maybe)) diff --git a/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir b/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir index 46493a3fc74a..74bc57d050d0 100644 --- a/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir +++ b/mlir/test/Dialect/Rock/conv_to_gemm_bwd_data_empty_filter_slice.mlir @@ -1,16 +1,30 @@ -// backwardDataKernelIds used to emit phantom kernel IDs for stride phases whose -// filter slice is empty. - -// The pre-fix implementation expressed the slice extent as -// `divideCeil(filterDims[i] - iTilda[i], filTilda[i])`. LLVM's default -// `divideCeil` is the unsigned-converting overload, so a negative numerator wrapped to a huge positive value -- making an empty phase look like real GEMM work and emitting a bogus per-phase rock.gemm. -// The fix detects iTilda[i] >= filterDims[i] and short-circuits gemmKproduct to 0 so the phase is correctly excluded. -// We pin the count of `rock.gemm` ops emitted by `rock-conv-to-gemm`, since that pass calls backwardDataKernelIds once per ConvBwdDataOp and emits one rock.gemm per returned kernel ID. +// `backwardDataKernelIds` enumerates the per-stride phase GEMMs that make up +// a single strided backwards-data conv. A phase whose filter slice is empty +// (i.e. some `iTilda[i] >= filterDims[i]`) must be excluded; otherwise we +// emit a phantom GEMM with a degenerate or wrap-around K dimension. +// +// Two cooperating fixes guarantee that on this branch: +// 1. `backwardDataKernelIds` short-circuits `gemmKproduct` to 0 the moment +// it sees `iTilda[i] >= filterDims[i]`. This is needed because the +// slice extent is now computed with `llvm::divideCeil` (the unsigned +// overload) -- without the explicit guard a negative numerator would +// wrap to a huge positive value and the empty phase would be retained. +// The previous rocMLIR code relied on signed `math_util::integer_divide_ceil` +// to produce 0 for those numerators, which masked the missing guard. +// 2. The `usesV4R1=false` arm of `commonConvRewrite` in `ConvToGemm.cpp` +// now iterates the actual filtered ids (`for (int64_t kernelId : kernelIds)`) +// instead of the loop index (`for (size_t i = 0; i < kernelIds.size(); ++i)`). +// The pre-fix loop passed `i` as the kernel id, so e.g. ids `{0, 3}` were +// lowered as `{0, 1}`, silently emitting a degenerate GEMM for id 1. +// +// We pin the lowered IR by walking each emitted `rock.gemm` in order with a +// `CHECK-NOT: rock.gemm` between consecutive ids: the in-order matches pin the +// `kernelId` set (catching the index/id mismatch), and the interleaved +// CHECK-NOTs forbid any extra phantom-phase gemm anywhere in between or after. // Rank-2 reproducer for the originally failing shape: filTilda = {2, 3}, -// filterDims = {2, 1}. The fix prunes iTilda[1] in {1, 2}, leaving the -// kernel-ID set {0, 3}. Without the fix iTilda[1] == 2 (kernel IDs 2, 5) -// would also slip through with a wrapped slice extent, giving 4 gemms. +// filterDims = {2, 1}. The valid kernel ids are {0, 3}; ids {1, 2, 4, 5} +// all have `iTilda[i] >= filterDims[i]` for some i and must be skipped. // RUN: rocmlir-gen --operation=conv_bwd_data --arch %arch -v4r1 0 --kernel_id 0 \ // RUN: --fil_layout=gkcyx --in_layout=ngchw --out_layout=ngkhw \ // RUN: --batchsize=1 --groupsize=1 --in_channels=4 --out_channels=4 \ @@ -22,12 +36,16 @@ // RUN: | FileCheck %s --check-prefix=RANK2 // Rank-3 reproducer exercising the rank-3 switch arm of -// backwardDataKernelIds: -// filTilda = {2, 2, 3}, filterDims = {2, 2, 1}, so only iTilda[2] -// (the dim with filterDims[i] < filTilda[i]) can index out of bounds. -// The fix prunes iTilda[2] in {1, 2}, leaving kernel IDs -// {0, 3, 6, 9} == 4 gemms. Without the fix iTilda[2] == 2 would also -// slip through (kernel IDs 2, 5, 8, 11), giving 8 gemms. +// `backwardDataKernelIds`. `rocmlir-gen` packs the spatial dims in +// `[h, w, d]` order (depth is appended last in `parseConvDims`), so for +// `--fil_layout=gkc012` / `--in_layout=ngc012` / `--out_layout=ngk012`: +// filterDims = [fil_h, fil_w, fil_d] = [2, 1, 2] +// strideDims = [stride_h, stride_w, stride_d] = [2, 3, 2] +// dilationDims = [dilation_h, dilation_w, dilation_d] = [1, 2, 1] +// which gives filTilda = [2, 3, 2] (12 candidate ids). Walking ids 0..11 +// and rejecting any phase with `iTilda[i] >= filterDims[i]` in any dim +// leaves valid ids `{0, 1, 6, 7}`; the other 8 ids all have +// `iTilda[1] >= 1` (since fil_w = 1 makes any non-zero w phase empty). // RUN: rocmlir-gen --operation=conv_bwd_data --arch %arch -v4r1 0 --kernel_id 0 \ // RUN: --fil_layout=gkc012 --in_layout=ngc012 --out_layout=ngk012 \ // RUN: --batchsize=1 --groupsize=1 --in_channels=4 --out_channels=4 \ @@ -38,8 +56,43 @@ // RUN: | rocmlir-driver -c --mlir-print-ir-after=rock-conv-to-gemm 2>&1 \ // RUN: | FileCheck %s --check-prefix=RANK3 -// RANK2-COUNT-2: {{rock\.gemm[[:>:]]}} +// Pin each emitted rock.gemm's kernelId in order, with CHECK-NOT in +// between to forbid extra gemms. This composition simultaneously pins: +// - the gemm count (extra gemms would trip a CHECK-NOT), +// - the exact kernelId set (missing or extra ids fail the next CHECK), and +// - the index/id mismatch fix (the pre-fix loop emitted kernelIds 0 and 1 +// instead of {0, 3}, so the second CHECK for kernelId = 3 would fail). +// RANK2: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 0 : index +// RANK2-NOT: {{rock\.gemm[[:>:]]}} +// RANK2: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 3 : index // RANK2-NOT: {{rock\.gemm[[:>:]]}} -// RANK3-COUNT-4: {{rock\.gemm[[:>:]]}} +// RANK3: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 0 : index +// RANK3-NOT: {{rock\.gemm[[:>:]]}} +// RANK3: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 1 : index +// RANK3-NOT: {{rock\.gemm[[:>:]]}} +// RANK3: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 6 : index // RANK3-NOT: {{rock\.gemm[[:>:]]}} +// RANK3: {{rock\.gemm[[:>:]]}}{{.*}}kernelId = 7 : index +// RANK3-NOT: {{rock\.gemm[[:>:]]}} + +// V4R1=true validation: when usesV4R1 is true, `commonConvRewrite` does not +// iterate `backwardDataKernelIds`. Instead it dispatches a single GEMM for +// the user-supplied `kernelId`. That id must still land on a non-empty +// stride phase; otherwise `backwardDataV4R1`'s `llvm::divideCeil` (unsigned) +// would wrap a negative numerator into a multi-exabyte slice extent. +// `commonConvRewrite` validates the id against `backwardDataKernelIds(...)` +// and emits an op error before reaching the slice math. +// Same rank-2 config as above (valid ids {0, 3}); --kernel_id 2 is empty. +// RUN: rocmlir-gen --operation=conv_bwd_data --arch %arch -v4r1 1 --kernel_id 2 \ +// RUN: --fil_layout=gkcyx --in_layout=ngchw --out_layout=ngkhw \ +// RUN: --batchsize=1 --groupsize=1 --in_channels=4 --out_channels=4 \ +// RUN: --in_h=4 --in_w=8 --fil_h=2 --fil_w=1 \ +// RUN: --dilation_h=1 --dilation_w=2 \ +// RUN: --conv_stride_h=2 --conv_stride_w=3 \ +// RUN: --padding_h=0 --padding_w=0 \ +// RUN: | not rocmlir-driver -c 2>&1 \ +// RUN: | FileCheck %s --check-prefix=V4R1_EMPTY_PHASE + +// V4R1_EMPTY_PHASE: error: 'rock.conv_bwd_data' op v4r1 kernel id 2 has an empty filter slice and cannot be lowered +// V4R1_EMPTY_PHASE-SAME: valid v4r1 kernel ids for this convolution shape are {0, 3}