@@ -427,7 +427,7 @@ class TransposeConvStridedConverter : public OpRewritePattern<tosa::CustomOp> {
ShapedType inputTy = cast<ShapedType>(input.getType());
ShapedType weightTy = cast<ShapedType>(weight.getType());
ShapedType biasTy = cast<ShapedType>(bias.getType());
ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
auto resultTy = cast<RankedTensorType>(op->getResult(0).getType());

Type inputETy = inputTy.getElementType();
Type weightETy = weightTy.getElementType();
@@ -596,11 +596,6 @@ class TransposeConvStridedConverter : public OpRewritePattern<tosa::CustomOp> {
dilationVals = {1, 1};
}

// We want to capture the height and width values after dilation expansion,
// but before padding is added later on.
int64_t origWeightHeight = weightHeight;
int64_t origWeightWidth = weightWidth;

// Pad the weight so that it is modulo of the striding.
llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
weightPadding[3] =
@@ -739,57 +734,41 @@ class TransposeConvStridedConverter : public OpRewritePattern<tosa::CustomOp> {
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
convReshapeDims1Value);

// Effective pad = outPad + (paddedK - stride - 1) - (inPad * stride)
int64_t effPadTop = outPad[0] + (origWeightHeight - stride[0] - 1) -
inPadVals[0] * stride[0];
int64_t effPadLeft = outPad[2] + (origWeightWidth - stride[1] - 1) -
inPadVals[2] * stride[1];

// When we shrink from the original size to kPrime by grouping stride phases,
// we discard some positions that existed in the conceptual upsampled view.
// The total span of the original field is kOrig - 1, and the span
// represented after factoring is kPrime - 1. The difference is the values
// that have been lost.
int64_t kHPrime = restridedWeightTy.getDimSize(1);
int64_t kWPrime = restridedWeightTy.getDimSize(2);
auto lost = [](int64_t Korig, int64_t kPrime, int64_t S) {
return (Korig - 1) - (kPrime - 1) * S;
};
int64_t lostH = lost(origWeightHeight, kHPrime, stride[0]);
int64_t lostW = lost(origWeightWidth, kWPrime, stride[1]);

// If stride factoring compresses a dimension to a single spatial position,
// i.e., kPrime == 1, then we dropped a ring of values around that position.
// The adjustment pattern depends on which dimension has asymmetric padding.

// Height dimension compressed (kHPrime==1)
if (kHPrime == 1 && lostH > 0) {
int64_t adjustment = lostH / 2;
bool hasAsymmetricWidth = (weightPadding[4] != weightPadding[5]);
if (hasAsymmetricWidth) {
effPadTop -= adjustment;
effPadLeft += adjustment;
} else {
effPadLeft += adjustment;
}
}

// Width dimension compressed (kWPrime==1)
if (kWPrime == 1 && lostW > 0) {
effPadTop += lostW / 2;
}
// After factoring stride phases out of the filter channels, a contribution
// from gradient-output index h and original filter index k lands in the
// expanded stride-1 conv result at:
// ho_expanded = h * stride + k + pad_low * stride - stride * (kPrime - 1)
// where kPrime is the reduced spatial filter size used by the stride-1
// convolution. The output of this op is dx[i] (with i = h*stride + k -
// pad_low) shifted by out_pad_low. So output position 0 corresponds to
// h*stride + k = pad_low - out_pad_low, which substituted above gives the
// low-side offset into the expanded conv result:
// offset = pad_low * (stride + 1) - stride * (kPrime - 1) - out_pad_low
// Positive offset means we crop the expanded result on the low side;
// negative offset means we pre-pad the result on the low side.
auto computeLowSideOffset = [](int64_t inPadLow, int64_t outPadLow,
int64_t strideVal, int64_t kPrime) {
return inPadLow * (strideVal + 1) - strideVal * (kPrime - 1) - outPadLow;
};
int64_t offsetTop =
computeLowSideOffset(inPadVals[0], outPad[0], stride[0], kHPrime);
int64_t offsetLeft =
computeLowSideOffset(inPadVals[2], outPad[2], stride[1], kWPrime);

int64_t resultSliceTop;
int64_t resultSliceLeft;
int64_t resultPadTop;
int64_t resultPadLeft;
// Convert effective padding into slice (crop) and post-pad just like the
// prior logic but now using effPad*.
// Convert low-side offset into slice (crop) and post-pad. Both branches
// feed into the shared clamping + zero-result-overlap logic below.
if (op->hasAttr("pad")) {
resultSliceTop = std::max<int64_t>(0, -effPadTop);
resultSliceLeft = std::max<int64_t>(0, -effPadLeft);
resultPadTop = std::max<int64_t>(0, effPadTop);
resultPadLeft = std::max<int64_t>(0, effPadLeft);
resultSliceTop = std::max<int64_t>(0, offsetTop);
resultSliceLeft = std::max<int64_t>(0, offsetLeft);
resultPadTop = std::max<int64_t>(0, -offsetTop);
resultPadLeft = std::max<int64_t>(0, -offsetLeft);
} else {
// Default to using legacy logic if input padding is not present
resultSliceTop = std::max<int64_t>(0, -outPad[0]);
@@ -798,39 +777,67 @@ class TransposeConvStridedConverter : public OpRewritePattern<tosa::CustomOp> {
resultPadLeft = std::max<int64_t>(0, outPad[2]);
}

// Try to slice the targetted result size, cap to the convolutions width.
int64_t resultSliceHeight =
std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
resultTy.getDimSize(1) - resultPadTop);
int64_t resultSliceWidth =
std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
resultTy.getDimSize(2) - resultPadLeft);

llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
resultSliceLeft, 0};
llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
convReshapeDims1.end());
sliceSize[1] = resultSliceHeight;
sliceSize[2] = resultSliceWidth;

auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
getTosaConstShape(rewriter, loc, sliceBegin),
getTosaConstShape(rewriter, loc, sliceSize))
.getResult();

llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
resultPadding[2] = resultPadTop;
resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
resultPadding[4] = resultPadLeft;
resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

Value resultPaddingVal =
getTosaConstShape(rewriter, op->getLoc(), resultPadding);

Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);
int64_t resultHeight = resultTy.getDimSize(1);
int64_t resultWidth = resultTy.getDimSize(2);
int64_t convExpandedHeight = convReshapeDims1[1];
int64_t convExpandedWidth = convReshapeDims1[2];

// Extreme low-side padding/cropping can leave no overlap with the expanded
// convolution result. Keep the slice window valid and let padding fill the
// requested result extent.
resultPadTop = std::min(resultPadTop, resultHeight);
resultPadLeft = std::min(resultPadLeft, resultWidth);
resultSliceTop = std::min(resultSliceTop, convExpandedHeight);
resultSliceLeft = std::min(resultSliceLeft, convExpandedWidth);

// Try to slice the targeted result size, cap to the convolution extent.
int64_t resultSliceHeight = std::min<int64_t>(
convExpandedHeight - resultSliceTop, resultHeight - resultPadTop);
int64_t resultSliceWidth = std::min<int64_t>(
convExpandedWidth - resultSliceLeft, resultWidth - resultPadLeft);

// The clamping above guarantees both arguments to each `min` are
// non-negative, so the slice extents must be too.
assert(resultSliceHeight >= 0 && resultSliceWidth >= 0 &&
"slice extents must be non-negative after clamping");

Value resultPad;
if (resultSliceHeight == 0 || resultSliceWidth == 0) {
// TOSA requires positive slice sizes. If the output window has no
// overlap with the expanded convolution result, materialize the pre-bias
// result directly as zeros.
resultPad = tosa::ConstOp::create(
rewriter, loc, resultTy,
DenseElementsAttr::get(resultTy, rewriter.getZeroAttr(resultETy)));
} else {
llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
resultSliceLeft, 0};
llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
convReshapeDims1.end());
sliceSize[1] = resultSliceHeight;
sliceSize[2] = resultSliceWidth;

auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy),
conv2d, getTosaConstShape(rewriter, loc, sliceBegin),
getTosaConstShape(rewriter, loc, sliceSize))
.getResult();

llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
resultPadding[2] = resultPadTop;
resultPadding[3] = resultHeight - resultPadTop - sliceSize[1];
resultPadding[4] = resultPadLeft;
resultPadding[5] = resultWidth - resultPadLeft - sliceSize[2];
assert(resultPadding[3] >= 0 && resultPadding[5] >= 0 &&
"post-slice pad extents must be non-negative");

Value resultPaddingVal =
getTosaConstShape(rewriter, op->getLoc(), resultPadding);

resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);
}

if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
return failure();
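To make the new cropping arithmetic concrete, here is a minimal standalone sketch (plain C++, no MLIR dependencies; variable names mirror the pattern but are otherwise illustrative) of the offset → clamp → slice computation, worked on the numbers from the `bwd_data_conv2d_empty_slice` regression test added below:

```cpp
// Sketch of the low-side offset -> clamp -> slice arithmetic in
// TransposeConvStridedConverter, on the bwd_data_conv2d_empty_slice numbers.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Same formula as the computeLowSideOffset lambda in the pattern.
int64_t computeLowSideOffset(int64_t inPadLow, int64_t outPadLow,
                             int64_t stride, int64_t kPrime) {
  return inPadLow * (stride + 1) - stride * (kPrime - 1) - outPadLow;
}

int main() {
  // Height dim of the test: weightHeight = 4, stride_h = 2 -> kHPrime = 2;
  // pad_low = 2, out_pad_low = 0.
  int64_t offsetTop = computeLowSideOffset(2, 0, 2, 2); // 2*3 - 2*1 - 0 = 4
  int64_t resultHeight = 2;       // requested dx height
  int64_t convExpandedHeight = 4; // height of the expanded stride-1 result

  // Positive offset crops the low side; negative offset would pre-pad it.
  int64_t sliceTop = std::max<int64_t>(0, offsetTop);
  int64_t padTop = std::max<int64_t>(0, -offsetTop);

  // Clamp so the slice window cannot leave the expanded result.
  padTop = std::min(padTop, resultHeight);
  sliceTop = std::min(sliceTop, convExpandedHeight);

  int64_t sliceHeight =
      std::min(convExpandedHeight - sliceTop, resultHeight - padTop);
  assert(sliceHeight >= 0); // guaranteed by the clamping above
  // sliceHeight == 0 here: no overlap, so the lowering emits a zero constant
  // instead of a zero-extent tosa.slice.
  printf("offsetTop=%lld sliceTop=%lld sliceHeight=%lld\n",
         (long long)offsetTop, (long long)sliceTop, (long long)sliceHeight);
}
```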
mlir/lib/Dialect/Rock/Transforms/ConvToGemm.cpp (33 additions & 11 deletions)
@@ -53,6 +53,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <iterator>
#include <tuple>

@@ -830,9 +831,16 @@ FailureOr<std::tuple<Value, Value, Value>> backwardDataV4R1(ConvBwdDataOp op,
iTilda[1] = (kernelId % product) / divisor;
iTilda[0] = kernelId / product;
}
for (size_t i = 0; i < convDims.fil.size(); i++)
iDotSlice.push_back(math_util::integer_divide_ceil(
convDims.fil[i] - iTilda[i], filTilda[i]));
// `kernelId` must come from `backwardDataKernelIds`, which filters out
// phases where `iTilda[i] >= convDims.fil[i]`. Without that filter,
// `divideCeil`'s unsigned-converting overload would wrap a negative
// numerator into a huge value here.
for (size_t i = 0; i < convDims.fil.size(); i++) {
assert(iTilda[i] < convDims.fil[i] &&
"kernelId not pre-filtered by backwardDataKernelIds");
iDotSlice.push_back(
llvm::divideCeil(convDims.fil[i] - iTilda[i], filTilda[i]));
}

// backward data only, it's igemm v4r1 algo
// c is input channels , k is output channels
@@ -1115,21 +1123,35 @@ commonConvRewrite(T op, PatternRewriter &b, ConvolutionContext &ctx,
if (ConvOpType::BwdData == convOpType) {
auto bwdDataOp = cast<ConvBwdDataOp>(op);
bool usesV4R1 = op->template getAttrOfType<BoolAttr>("usesV4R1").getValue();
auto strideDims = ctx.getStrideVal();
auto dilationDims = ctx.getDilationVal();
auto filterDims = ctx.getConvDims().fil;
auto kernelIds = rock::backwardDataKernelIds(strideDims, dilationDims,
filterDims, /*usesV4R1=*/true);
if (usesV4R1) {
auto kernelId = bwdDataOp.getKernelIdAttr().getInt();
// `backwardDataV4R1` requires that `iTilda[i] < filterDims[i]` for every
// dimension; otherwise its `llvm::divideCeil` (unsigned) wraps a
// negative numerator into a huge slice extent. Reject ids that do not
// correspond to a real, non-empty stride phase.
if (!llvm::is_contained(kernelIds, kernelId)) {
InFlightDiagnostic diag =
bwdDataOp.emitOpError()
<< "v4r1 kernel id " << kernelId
<< " has an empty filter slice and cannot be lowered; valid v4r1 "
"kernel ids for this convolution shape are {";
for (auto [i, id] : llvm::enumerate(kernelIds))
diag << (i == 0 ? "" : ", ") << id;
diag << "}";
return failure();
}
return backwardDataV4R1(bwdDataOp, b, kernelId, usesV4R1);
} else {
// For the cases where the V4R1 algorithm requires more than one kernel,
// i.e., stride != dilation, we want to create multiple GEMMs in a
// single kernel
auto strideDims = ctx.getStrideVal();
auto dilationDims = ctx.getDilationVal();
auto filterDims = ctx.getConvDims().fil;
auto numKernels =
rock::backwardDataKernelIds(strideDims, dilationDims, filterDims,
/*usesV4R1=*/true);
for (size_t i = 0; i < numKernels.size(); i++) {
auto maybe = backwardDataV4R1(bwdDataOp, b, i, usesV4R1);
for (int64_t kernelId : kernelIds) {
auto maybe = backwardDataV4R1(bwdDataOp, b, kernelId, usesV4R1);
if (failed(maybe))
return failure();
}
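A standalone illustration of the wrap hazard these comments and the new assert guard against; `divideCeilU` is a local stand-in for `llvm::divideCeil`'s `uint64_t` overload, not the actual LLVM implementation:

```cpp
#include <cstdint>
#include <cstdio>

// Local stand-in for llvm::divideCeil's unsigned overload.
static uint64_t divideCeilU(uint64_t numerator, uint64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  int64_t fil = 3, iTilda = 4, filTilda = 1; // a phase with an empty slice
  // fil - iTilda == -1; the implicit conversion to uint64_t wraps it to
  // 2^64 - 1, so the "slice extent" comes out huge instead of empty.
  uint64_t bogus = divideCeilU(fil - iTilda, filTilda);
  printf("%llu\n", (unsigned long long)bogus); // 18446744073709551615
  // backwardDataKernelIds rejects such phases up front, which is what the
  // new assert in backwardDataV4R1 documents.
  return 0;
}
```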
mlir/lib/Dialect/Rock/utility/loweringUtils.cpp (10 additions & 7 deletions)
@@ -28,6 +28,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
@@ -120,6 +121,7 @@ mlir::rock::backwardDataKernelIds(ArrayRef<int64_t> strideDims,
ArrayRef<int64_t> dilationDims,
ArrayRef<int64_t> filterDims, bool usesV4R1) {
assert(strideDims.size() == dilationDims.size());

SmallVector<int64_t, 5> gcdStrideDilations;
for (const auto &[stride, dilation] : zip(strideDims, dilationDims))
gcdStrideDilations.push_back(math_util::gcd(stride, dilation));
@@ -139,7 +141,6 @@
for (int64_t kernelId = 0; kernelId < product; ++kernelId) {
// gemmK size is different for each GEMM
SmallVector<int64_t, 3> iTilda;
SmallVector<int64_t, 3> iDotSlice;
int64_t divisor = 1;
iTilda.resize(filterDims.size());
switch (filterDims.size()) {
@@ -154,14 +155,16 @@
iTilda[1] = (kernelId % subproduct) / divisor;
iTilda[0] = kernelId / subproduct;
}
for (size_t i = 0; i < filterDims.size(); i++)
iDotSlice.push_back(math_util::integer_divide_ceil(
filterDims[i] - iTilda[i], filTilda[i]));

// gemmK must > 0, otherwise not need to run
// gemmK must be > 0, otherwise this kernel has no filter slice to run.
int64_t gemmKproduct = 1;
for (int64_t ds : iDotSlice)
gemmKproduct *= ds;
for (size_t i = 0; i < filterDims.size(); i++) {
if (iTilda[i] >= filterDims[i]) {
gemmKproduct = 0;
break;
}
gemmKproduct *= llvm::divideCeil(filterDims[i] - iTilda[i], filTilda[i]);
}
if (gemmKproduct > 0) {
kernelIds.push_back(kernelId);
}
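For illustration, the enumeration-and-filter loop above as a self-contained sketch for the 2-D case, run on the shape of the second new test (stride [3,3], dilation [1,1], filter [3,2]). This assumes `filTilda[i]` is the `stride[i] / gcd(stride[i], dilation[i])` phase factor — its definition sits behind the fold above — and uses local stand-ins for `math_util::gcd` and `llvm::divideCeil`:

```cpp
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  std::vector<int64_t> stride = {3, 3}, dilation = {1, 1}, fil = {3, 2};
  std::vector<int64_t> tilda(2);
  for (int i = 0; i < 2; ++i)
    tilda[i] = stride[i] / std::gcd(stride[i], dilation[i]); // phases per dim

  for (int64_t id = 0; id < tilda[0] * tilda[1]; ++id) {
    // Decompose the flat kernel id into per-dimension phase indices.
    int64_t iTilda[2] = {id / tilda[1], id % tilda[1]};
    int64_t gemmK = 1;
    for (int i = 0; i < 2; ++i)
      gemmK *= iTilda[i] < fil[i] ? ceilDiv(fil[i] - iTilda[i], tilda[i]) : 0;
    if (gemmK > 0) // only phases with a non-empty filter slice survive
      printf("kernelId %lld kept\n", (long long)id);
  }
  // stride [3,3] x filter [3,2]: ids 2, 5, 8 are dropped (iTilda[1] == 2
  // exceeds the width-2 filter), leaving six runnable kernels.
}
```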
@@ -108,4 +108,59 @@ func.func @bwd_data_conv1d(%arg0: tensor<64xf32>, %arg1: tensor<672xf32>, %arg2:
}

// -----
// Regression test for the no-overlap zero-result branch in
// TransposeConvStridedConverter. With weightHeight=4 and stride_h=2 we get
// kHPrime=2, so the low-side height offset is
// inPadLow*(stride+1) - stride*(kPrime-1) - outPadLow = 2*3 - 2*1 - 0 = 4.
// convExpandedHeight = 4, so resultSliceTop saturates at convExpandedHeight
// and resultSliceHeight collapses to 0. The lowering must materialize the
// pre-bias result as a zero constant rather than emit a zero-extent
// tosa.slice (TOSA disallows empty slice extents).
//
// CHECK-LABEL: func @bwd_data_conv2d_empty_slice
// CHECK-NOT: tosa.custom
// CHECK-NOT: tosa.slice
// CHECK: %[[Z:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x2x4x1xf32>}> : () -> tensor<1x2x4x1xf32>
// CHECK: tosa.add %[[Z]], %{{.*}} : (tensor<1x2x4x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x2x4x1xf32>
func.func @bwd_data_conv2d_empty_slice(%grad_out: tensor<1x1x4x1xf32>, %weight: tensor<1x4x1x1xf32>) -> tensor<1x2x4x1xf32> {
%izp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%wzp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%bias = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.custom %grad_out, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array<i64: 1, 1>, domain_name = "rocmlir", group = 1 : i64, implementation_attrs = "", operator_name = "conv_bwd_data", out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 2, 0, 0, 0>, stride = array<i64: 2, 1>} : (tensor<1x1x4x1xf32>, tensor<1x4x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4x1xf32>
return %0 : tensor<1x2x4x1xf32>
}

// -----
// Locks in the new geometric low-side offset formula on the case where
// the legacy formula's `kHPrime == 1 && lostH > 0 && hasAsymmetricWidth`
// fixup was load-bearing.
//
// Shape: weight 1x3x2x1, stride [3, 3], pad [0,0,0,0], dilation [1,1].
// `weightWidth (2) % stride[1] (3) != 0` so the modulo-stride pad makes
// `weightPadding[4] (0) != weightPadding[5] (1)`, which under the old
// code took the asymmetric-width branch with adjustment = lostH/2 = 1
// (lostH = (3-1) - (1-1)*3 = 2). That branch produced
// effPadTop = -1 - 1 = -2, effPadLeft = -2 + 1 = -1
// -> slice begin [0, 2, 1, 0] size [1, 4, 5, 1].
// The new formula is independent of the asymmetric-width branch:
// offsetTop = inPadLow*(stride+1) - stride*(kHPrime-1) - outPadLow
// = 0*4 - 3*0 - 0 = 0
// offsetLeft = 0
// so the slice now begins at the origin of the 6x6 expanded conv
// result and takes the full 6 rows / 5 columns of output.
//
// CHECK-LABEL: func @bwd_data_conv2d_kprime_one_asym_weight_pad
// CHECK-NOT: tosa.custom
// CHECK: %[[SBEGIN:.*]] = tosa.const_shape {{.*}}values = dense<0> : tensor<4xindex>{{.*}} -> !tosa.shape<4>
// CHECK: %[[SSIZE:.*]] = tosa.const_shape {{.*}}values = dense<[1, 6, 5, 1]> : tensor<4xindex>{{.*}} -> !tosa.shape<4>
// CHECK: tosa.slice %{{.*}}, %[[SBEGIN]], %[[SSIZE]] : (tensor<1x6x6x1xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x6x5x1xf32>
func.func @bwd_data_conv2d_kprime_one_asym_weight_pad(%grad_out: tensor<1x2x2x1xf32>, %weight: tensor<1x3x2x1xf32>) -> tensor<1x6x5x1xf32> {
%izp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%wzp = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%bias = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.custom %grad_out, %weight, %bias, %izp, %wzp {acc_type = f32, dilation = array<i64: 1, 1>, domain_name = "rocmlir", group = 1 : i64, implementation_attrs = "", operator_name = "conv_bwd_data", out_pad = array<i64: 0, 0, 0, 0>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 3, 3>} : (tensor<1x2x2x1xf32>, tensor<1x3x2x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x5x1xf32>
return %0 : tensor<1x6x5x1xf32>
}
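A quick numeric check of the comment arithmetic above, as a sketch; `roundUpTo` stands in for the modulo-stride weight padding and the lambda mirrors `computeLowSideOffset` (both illustrative):

```cpp
#include <cstdint>
#include <cstdio>

static int64_t roundUpTo(int64_t v, int64_t m) { return ((v + m - 1) / m) * m; }

int main() {
  int64_t weightH = 3, weightW = 2, strideH = 3, strideW = 3;
  // Pad each filter dim up to a multiple of the stride, then factor the
  // stride phases out of it.
  int64_t kHPrime = roundUpTo(weightH, strideH) / strideH; // 3/3 = 1
  int64_t kWPrime = roundUpTo(weightW, strideW) / strideW; // 3/3 = 1
  auto offset = [](int64_t inPadLow, int64_t outPadLow, int64_t stride,
                   int64_t kPrime) {
    return inPadLow * (stride + 1) - stride * (kPrime - 1) - outPadLow;
  };
  // pad = [0,0,0,0] and out_pad = [0,0,0,0] in the test:
  printf("offsetTop=%lld offsetLeft=%lld\n",
         (long long)offset(0, 0, strideH, kHPrime),  // 0 -> slice from row 0
         (long long)offset(0, 0, strideW, kWPrime)); // 0 -> slice from col 0
}
```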

// -----
