|
14 | 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | 15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
16 | 16 | #include "mlir/Dialect/Rock/IR/Rock.h" |
| 17 | +#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h" |
17 | 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | 19 | #include "mlir/IR/AffineExpr.h" |
19 | 20 | #include "mlir/IR/PatternMatch.h" |
@@ -189,7 +190,7 @@ LogicalResult MatmulConverter<LinalgMatOp>::matchAndRewrite( |
189 | 190 | } |
190 | 191 | MatmulContext context = maybeContext.value(); |
191 | 192 |
|
192 | | - // TODO: handle split K attributes as well |
| 193 | + // TODO: see (AIROCMLIR-696) |
193 | 194 | // TODO: handle broadcasting for matrix A and B |
194 | 195 | RankedTensorType outputType = |
195 | 196 | cast<RankedTensorType>(op.getOutputs()[0].getType()); |
@@ -259,9 +260,273 @@ LogicalResult ExpandStrideConverter::matchAndRewrite( |
259 | 260 | return success(); |
260 | 261 | } |
261 | 262 |
|
| 263 | +//===----------------------------------------------------------------------===// |
| 264 | +// ConvLinalgConverter: linalg.generic (conv) -> rock.conv |
| 265 | +//===----------------------------------------------------------------------===// |
namespace {
/// Attributes extracted from a linalg.generic that was tagged as a
/// convolution, normalized to the form rock.conv expects.
struct ConvFields {
  rock::LinalgConvType type;            // which conv variant (1-D/2-D/3-D)
  int64_t spatialDim;                   // number of spatial dimensions
  ArrayAttr padding, stride, dilation;  // index arrays in rock's layout order
  StringAttr perfConfig;                // optional tuning config; may be null
};
} // namespace
| 274 | + |
| 275 | +static int64_t getSpatialDim(rock::LinalgConvType type) { |
| 276 | + switch (type) { |
| 277 | + case rock::LinalgConvType::Conv1dNgchGkch: |
| 278 | + return 1; |
| 279 | + case rock::LinalgConvType::Conv2dNgchwGkchw: |
| 280 | + return 2; |
| 281 | + case rock::LinalgConvType::Conv3dNgchwdGkchwd: |
| 282 | + return 3; |
| 283 | + } |
| 284 | + llvm_unreachable("unknown LinalgConvType"); |
| 285 | +} |
| 286 | + |
| 287 | +/// Set filter_layout, input_layout, and output_layout on a rock.conv op. |
| 288 | +/// Layouts match the linalg convention: GKC*, NGC*, NGK*. |
| 289 | +static void setConvLayoutAttrs(OpBuilder &builder, rock::ConvOp cop, |
| 290 | + int64_t spatialDim) { |
| 291 | + auto *ctx = builder.getContext(); |
| 292 | + auto setLayout = [&](StringRef attrName, ArrayRef<StringRef> prefix, |
| 293 | + StringRef suffix) { |
| 294 | + SmallVector<Attribute> layout; |
| 295 | + for (StringRef dim : prefix) |
| 296 | + layout.push_back(StringAttr::get(ctx, dim)); |
| 297 | + for (int64_t i = 0; i < spatialDim; ++i) |
| 298 | + layout.push_back(StringAttr::get(ctx, Twine(i) + suffix)); |
| 299 | + cop->setAttr(attrName, builder.getArrayAttr(layout)); |
| 300 | + }; |
| 301 | + |
| 302 | + // The input layout in the operand of linalg.generic is NGC*, and |
| 303 | + // the filter layout is GKC*. We have to transfer these attribute |
| 304 | + // because later on in the pass, ConvToGemm expect them to be attached. |
| 305 | + setLayout("filter_layout", {"g", "k", "c"}, ""); |
| 306 | + setLayout("input_layout", {"ni", "gi", "ci"}, "i"); |
| 307 | + setLayout("output_layout", {"no", "go", "ko"}, "o"); |
| 308 | +} |
| 309 | + |
| 310 | +/// Remove the tensor.pad + tensor.expand_shape pattern emitted by |
| 311 | +/// migraphx-to-linalg, replacing it with just tensor.expand_shape on the |
| 312 | +/// unpadded source. rock.conv handles padding internally. |
| 313 | +/// |
| 314 | +/// Expected IR structure: |
| 315 | +/// %padded = tensor.pad %original ... |
| 316 | +/// %expanded = tensor.expand_shape %padded ... |
| 317 | +/// Replaced with: |
| 318 | +/// %expanded = tensor.expand_shape %original ... |
| 319 | +static FailureOr<Value> |
| 320 | +removePaddingFromInput(ConversionPatternRewriter &rewriter, |
| 321 | + linalg::GenericOp op, Value in, ArrayAttr padding) { |
| 322 | + bool hasPadding = llvm::any_of(padding.getValue(), [](Attribute attr) { |
| 323 | + return cast<IntegerAttr>(attr).getInt() != 0; |
| 324 | + }); |
| 325 | + if (!hasPadding) |
| 326 | + return in; |
| 327 | + |
| 328 | + auto expanded = in.getDefiningOp<tensor::ExpandShapeOp>(); |
| 329 | + auto padded = (expanded != nullptr) |
| 330 | + ? expanded->getOperand(0).getDefiningOp<tensor::PadOp>() |
| 331 | + : nullptr; |
| 332 | + // We require padding here to have one use because the code structure emitted |
| 333 | + // by the MIGraphX -> Linalg have one use. In theory, you don't need this |
| 334 | + // check, but better be safe than sorry. This goes with expanded as well |
| 335 | + if (!padded || !padded->hasOneUse()) { |
| 336 | + op.emitError("unexpected padding code structure"); |
| 337 | + return failure(); |
| 338 | + } |
| 339 | + |
| 340 | + if (!expanded || !expanded->hasOneUse()) { |
| 341 | + return op.emitError("unexpected group expansion shape code structure"); |
| 342 | + } |
| 343 | + |
| 344 | + SmallVector<int64_t> resultShape(expanded.getResultType().getShape()); |
| 345 | + // The tensor.pad operand has no group dimension: [N, G*C, spatial...]. |
| 346 | + // The expanded result has [N, G, C, spatial_padded...]. Take the first 3 |
| 347 | + // dims (N, G, C) from the expanded shape and append the unpadded spatial |
| 348 | + // dims directly from the pad source starting at position 2. |
| 349 | + auto padSourceShape = |
| 350 | + cast<RankedTensorType>(padded.getOperand(0).getType()).getShape(); |
| 351 | + resultShape.resize(3); |
| 352 | + resultShape.insert(resultShape.begin() + 3, padSourceShape.begin() + 2, |
| 353 | + padSourceShape.end()); |
| 354 | + |
| 355 | + RankedTensorType newResultType = RankedTensorType::get( |
| 356 | + resultShape, padded.getResultType().getElementType()); |
| 357 | + Value result = tensor::ExpandShapeOp::create( |
| 358 | + rewriter, expanded.getLoc(), newResultType, padded.getOperand(0), |
| 359 | + expanded.getReassociationIndices()); |
| 360 | + // erase the operations as well |
| 361 | + rewriter.eraseOp(expanded); |
| 362 | + rewriter.eraseOp(padded); |
| 363 | + return result; |
| 364 | +} |
| 365 | + |
namespace {
/// Converts a linalg.generic that was tagged as a convolution (via the
/// rock::linalgConvOpAttrName attribute) into rock.conv, transferring the
/// pad/stride/dilation/perf_config attributes and attaching the layout
/// attributes ConvToGemm expects. Conv1D is widened to Conv2D.
struct ConvLinalgConverter final
    : public OpConversionPattern<linalg::GenericOp> {
  using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
  using OpConversionPattern<linalg::GenericOp>::getTypeConverter;
  using OpAdaptor = typename OpConversionPattern<linalg::GenericOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  /// Extracts and normalizes the conv attributes from `op`. Returns failure
  /// silently when `op` is not tagged as a conv, and with a diagnostic when
  /// it is tagged but its attributes are malformed.
  FailureOr<ConvFields> isConv(ConversionPatternRewriter &rewriter,
                               linalg::GenericOp op) const;
};
} // namespace
| 382 | + |
| 383 | +FailureOr<ConvFields> |
| 384 | +ConvLinalgConverter::isConv(ConversionPatternRewriter &rewriter, |
| 385 | + linalg::GenericOp op) const { |
| 386 | + auto name = |
| 387 | + op->getAttrOfType<rock::LinalgConvTypeAttr>(rock::linalgConvOpAttrName); |
| 388 | + if (!name) |
| 389 | + return failure(); |
| 390 | + rock::LinalgConvType convType = name.getValue(); |
| 391 | + int64_t spatialDim = getSpatialDim(convType); |
| 392 | + // Conv1D is broadcasted into Conv2D. To check for error, we |
| 393 | + // use effectiveDim instead because it one has more stride/dilation |
| 394 | + // in the expanded dimension |
| 395 | + int64_t effectiveDim = (spatialDim == 1) ? spatialDim + 1 : spatialDim; |
| 396 | + |
| 397 | + auto convertToArrayAttr = |
| 398 | + [&](Attribute arr, ArrayRef<int64_t> dimOneDefaults = {}) -> ArrayAttr { |
| 399 | + if (!arr || !isa<ArrayAttr>(arr)) { |
| 400 | + return ArrayAttr{}; |
| 401 | + } |
| 402 | + |
| 403 | + SmallVector<int64_t, 4> values; |
| 404 | + llvm::transform( |
| 405 | + cast<ArrayAttr>(arr).getValue(), std::back_inserter(values), |
| 406 | + [](Attribute val) { return cast<IntegerAttr>(val).getInt(); }); |
| 407 | + // Conv1D is expanded into Conv2D: append identity defaults for the |
| 408 | + // extra spatial dimension (stride=1, dilation=1, pad=0). |
| 409 | + if (spatialDim == 1) |
| 410 | + values.insert(values.end(), dimOneDefaults.begin(), dimOneDefaults.end()); |
| 411 | + return rewriter.getIndexArrayAttr(values); |
| 412 | + }; |
| 413 | + |
| 414 | + auto dilation = |
| 415 | + convertToArrayAttr(op->getAttr("dilation"), /*dimOneDefaults=*/{1}); |
| 416 | + auto stride = |
| 417 | + convertToArrayAttr(op->getAttr("stride"), /*dimOneDefaults=*/{1}); |
| 418 | + if (!dilation || static_cast<int64_t>(dilation.size()) != effectiveDim) { |
| 419 | + op.emitError("invalid dilation"); |
| 420 | + return failure(); |
| 421 | + } |
| 422 | + |
| 423 | + if (!stride || static_cast<int64_t>(stride.size()) != effectiveDim) { |
| 424 | + op.emitError("invalid stride"); |
| 425 | + return failure(); |
| 426 | + } |
| 427 | + |
| 428 | + // Input format: [dim0_low, dim1_low, ..., dim0_high, dim1_high, ...] |
| 429 | + // Rock format: [dim0_low, dim0_high, dim1_low, dim1_high, ...] |
| 430 | + auto originalPadding = convertToArrayAttr(op->getAttr("pad")); |
| 431 | + if (!originalPadding) { |
| 432 | + op.emitError("no padding found"); |
| 433 | + return failure(); |
| 434 | + } |
| 435 | + int64_t numSpatial = originalPadding.size() / 2; |
| 436 | + SmallVector<Attribute, 8> interleavedPad; |
| 437 | + for (int64_t i = 0; i < numSpatial; ++i) { |
| 438 | + interleavedPad.push_back(originalPadding[i]); |
| 439 | + interleavedPad.push_back(originalPadding[numSpatial + i]); |
| 440 | + } |
| 441 | + // Conv1D is expanded into Conv2D |
| 442 | + if (spatialDim == 1) { |
| 443 | + interleavedPad.push_back(rewriter.getIndexAttr(0)); |
| 444 | + interleavedPad.push_back(rewriter.getIndexAttr(0)); |
| 445 | + } |
| 446 | + auto padding = rewriter.getArrayAttr(interleavedPad); |
| 447 | + // note that Conv1D is expanded into Conv2D |
| 448 | + if (effectiveDim * 2 != (int64_t)padding.size()) { |
| 449 | + op.emitError("invalid number of padding"); |
| 450 | + return failure(); |
| 451 | + } |
| 452 | + |
| 453 | + StringAttr perfConfig = op->getAttrOfType<StringAttr>("perf_config"); |
| 454 | + return ConvFields{convType, spatialDim, padding, |
| 455 | + stride, dilation, perfConfig}; |
| 456 | +} |
| 457 | + |
LogicalResult ConvLinalgConverter::matchAndRewrite(
    linalg::GenericOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Fails silently when op is not a tagged conv; isConv emits its own
  // diagnostics for tagged-but-malformed attributes.
  FailureOr<ConvFields> maybeConv = isConv(rewriter, op);
  if (failed(maybeConv))
    return failure();

  ConvFields conv = *maybeConv;
  Location loc = op.getLoc();

  // rock.conv pads internally, so strip the tensor.pad that
  // migraphx-to-linalg inserted before the group expand_shape.
  auto maybeInput =
      removePaddingFromInput(rewriter, op, op.getOperand(0), conv.padding);
  if (failed(maybeInput))
    return failure();

  Value input = *maybeInput;
  Value filter = op.getOperand(1);

  // Conv1D is expanded into Conv2D: unmerge the single spatial dim
  // into (spatial, W=1) for filter and input.
  int64_t effectiveSpatialDim = conv.spatialDim;
  if (conv.spatialDim == 1) {
    effectiveSpatialDim = 2;
    auto filterShape = cast<RankedTensorType>(filter.getType()).getShape();
    rock::BottomUpTMBuilder builder(rewriter, {"g", "k", "c", "0"}, filterShape,
                                    loc);
    builder.passThrough({"gf", "kf", "cf"}, {0, 1, 2}, {"g", "k", "c"});
    builder.unmerge({"0f", "1f"}, {3, 4}, "0", {filterShape[3], 1});
    filter = rock::TransformOp::create(rewriter, loc, filter, builder.get());

    auto inputShape = cast<RankedTensorType>(input.getType()).getShape();
    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "c", "0"}, inputShape, loc);
    b.passThrough({"nu", "gu", "cu"}, {0, 1, 2}, {"n", "g", "c"});
    b.unmerge({"0u", "1u"}, {3, 4}, "0", {inputShape[3], 1});
    input = rock::TransformOp::create(rewriter, loc, input, b.get());
  }

  // The rock.conv result carries the extra W=1 dim in the 1-D case, so its
  // type is derived from the linalg result type plus that trailing 1.
  RankedTensorType linalgResultType =
      cast<RankedTensorType>(op.getResult(0).getType());
  SmallVector<int64_t> rockShape(linalgResultType.getShape());
  if (conv.spatialDim == 1)
    rockShape.push_back(1);
  RankedTensorType rockResultType =
      RankedTensorType::get(rockShape, linalgResultType.getElementType());
  Value output =
      bufferization::AllocTensorOp::create(rewriter, loc, rockResultType, {});
  // features/blockSize/gridSize/params are left null here; presumably they
  // are filled in by later tuning/lowering passes — TODO confirm.
  auto cop = rock::ConvOp::create(rewriter, loc, rockResultType, filter, input,
                                  output, /*features=*/nullptr,
                                  /*blockSize=*/nullptr, /*gridSize=*/nullptr,
                                  conv.padding, conv.stride, conv.dilation,
                                  /*params=*/nullptr);
  // TODO: add splitk see (AIROCMLIR-696)
  if (conv.perfConfig)
    cop->setAttr("perf_config", conv.perfConfig);
  setConvLayoutAttrs(rewriter, cop, effectiveSpatialDim);

  // For the Conv1D case, merge the two spatial dims back together so the
  // replacement value matches the original linalg result type.
  Value result = cop.getResult();
  if (conv.spatialDim == 1) {
    auto shape = cast<RankedTensorType>(result.getType()).getShape();
    rock::BottomUpTMBuilder b(rewriter, {"n", "g", "k", "0", "1"}, shape, loc);
    b.passThrough({"no", "go", "ko"}, {0, 1, 2}, {"n", "g", "k"});
    b.merge("0o", 3, {"0", "1"});
    result = rock::TransformOp::create(rewriter, loc, result, b.get());
  }

  rewriter.replaceOp(op, result);
  return success();
}
| 526 | + |
void mlir::rock::populateLinalgToRockConversionPattern(
    RewritePatternSet &pattern, MLIRContext *context) {
  // Register all linalg -> rock lowerings: matmul-like ops (including the
  // generic form), stride expansion, and tagged convolutions.
  pattern.add<MatmulConverter<linalg::BatchMatmulOp>,
              MatmulConverter<linalg::MatmulOp>, ExpandStrideConverter,
              MatmulConverter<linalg::GenericOp>, ConvLinalgConverter>(context);
}
0 commit comments