
Commit a8ae8ac

[AIROCMLIR-552] Added Broadcasting Linalg Lowering Path (#2270)
* [AIROCMLIR-552] Added broadcasting
* Address comments
1 parent 839eb35 commit a8ae8ac

3 files changed: 338 additions & 13 deletions


mlir/lib/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.cpp

Lines changed: 250 additions & 1 deletion
@@ -376,6 +376,253 @@ ClipConverter::matchAndRewrite(migraphx::ClipOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Tensor views and shape manipulation
+//===----------------------------------------------------------------------===//
+namespace {
+struct BroadcastConverter final
+    : public OpConversionPattern<migraphx::BroadcastOp> {
+  using OpConversionPattern<migraphx::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+struct MultiBroadcastConverter final
+    : public OpConversionPattern<migraphx::MultiBroadcastOp> {
+  using OpConversionPattern<migraphx::MultiBroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Reshape the input Value into a new RankedTensorType with newShape.
+/// The input must have type RankedTensorType.
+static Value reshapeValue(ConversionPatternRewriter &rewriter, Value input,
+                          ArrayRef<int64_t> newShape) {
+  // Although there is a tensor.reshape op, we use tensor.collapse_shape
+  // and tensor.expand_shape since the rock-view-to-transform pass doesn't
+  // support tensor.reshape.
+  RankedTensorType currentType = cast<RankedTensorType>(input.getType());
+  Location loc = input.getLoc();
+  int64_t inputRank = currentType.getRank();
+  int64_t outputRank = static_cast<int64_t>(newShape.size());
+
+  if (currentType.getShape() == newShape) {
+    return input;
+  }
+
+  SmallVector<ReassociationIndices> collapseReassociation(1);
+  SmallVector<ReassociationIndices> expandReassociation(1);
+  collapseReassociation[0].resize(inputRank);
+  expandReassociation[0].resize(outputRank);
+  std::iota(collapseReassociation[0].begin(), collapseReassociation[0].end(),
+            0);
+  std::iota(expandReassociation[0].begin(), expandReassociation[0].end(), 0);
+  input = tensor::CollapseShapeOp::create(rewriter, loc, input,
+                                          collapseReassociation);
+  if (cast<RankedTensorType>(input.getType()).getShape() == newShape) {
+    return input;
+  }
+  RankedTensorType resultType =
+      RankedTensorType::get(newShape, currentType.getElementType());
+  input = tensor::ExpandShapeOp::create(rewriter, loc, resultType, input,
+                                        expandReassociation);
+  return input;
+}
+
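For the 1x768x2304 bias in the tests below, this helper collapses all dimensions to 1-D and then expands to the unit-free target shape. A minimal sketch of the emitted pair, where %bias is a placeholder name for the converted input (the exact tensor.expand_shape assembly, e.g. its output_shape clause, varies across MLIR versions):

```mlir
// Collapse the rank-3 input to 1-D, then expand to the rank-2 target shape.
%collapsed = tensor.collapse_shape %bias [[0, 1, 2]]
    : tensor<1x768x2304xf16> into tensor<1769472xf16>
%reshaped = tensor.expand_shape %collapsed [[0, 1]] output_shape [768, 2304]
    : tensor<1769472xf16> into tensor<768x2304xf16>
```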
+LogicalResult
+BroadcastConverter::matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  migraphx::MIXRShapedType input = op.getInput().getType();
+  migraphx::MIXRShapedType output = op.getOutput().getType();
+
+  RankedTensorType outputType =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(output));
+  if (!outputType) {
+    return op.emitError("cannot convert output type to ranked tensor type");
+  }
+
+  uint64_t axis = op.getAxis();
+  uint64_t outputRank = output.getRank();
+
+  uint64_t inputRank = input.getRank();
+  SmallVector<int64_t, 4> dimensionAttr;
+  llvm::transform(llvm::seq<int64_t>(0, axis),
+                  std::back_inserter(dimensionAttr),
+                  [](int64_t val) { return val; });
+  for (auto [index, dim] : llvm::enumerate(input.getShape())) {
+    // A size-1 dimension in the input can also be broadcast.
+    if (dim == 1) {
+      dimensionAttr.push_back(index + axis);
+    }
+  }
+  llvm::transform(llvm::seq<int64_t>(axis + inputRank, outputRank),
+                  std::back_inserter(dimensionAttr),
+                  [](int64_t val) { return val; });
+
+  // We have to drop the size-1 dimensions because we may be broadcasting
+  // them to a different size.
+  auto reshaped =
+      reshapeValue(rewriter, adaptor.getInput(),
+                   llvm::filter_to_vector(
+                       input.getShape(), [](int64_t val) { return val != 1; }));
+  auto init = tensor::EmptyOp::create(rewriter, loc, outputType.getShape(),
+                                      outputType.getElementType());
+  auto result =
+      linalg::BroadcastOp::create(rewriter, loc, reshaped, init, dimensionAttr);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
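For example, with axis = 0 and a <1x768x2304xf16> input broadcast to out_lens = [64, 768, 2304] (the @matmul_broadcast_op test below), only the unit dimension's position lands in dimensionAttr, so after the unit dimension is reshaped away the pattern produces roughly this IR (a sketch; %reshaped names the unit-free input):

```mlir
// Broadcast the rank-2 input into dimension 0 of the rank-3 output.
%init = tensor.empty() : tensor<64x768x2304xf16>
%bcast = linalg.broadcast ins(%reshaped : tensor<768x2304xf16>)
    outs(%init : tensor<64x768x2304xf16>) dimensions = [0]
```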
+LogicalResult MultiBroadcastConverter::matchAndRewrite(
+    migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  migraphx::MIXRShapedType outMIXRType = op.getOutput().getType();
+  RankedTensorType outType =
+      cast<RankedTensorType>(getTypeConverter()->convertType(outMIXRType));
+  ArrayRef<int64_t> outShape = outType.getShape();
+  ArrayRef<int64_t> outStrides = outMIXRType.getStrides();
+  uint32_t inRank =
+      cast<RankedTensorType>(adaptor.getInput().getType()).getRank();
+  uint32_t outRank = outType.getRank();
+  Type elemType = outType.getElementType();
+
+  assert(outRank >= inRank && "MultiBroadcastOp shouldn't reduce rank. This "
+                              "should be an invariant of this operation");
+
+  // If the input is a splat constant, broadcast it trivially.
+  if (auto constOp = adaptor.getInput().getDefiningOp<arith::ConstantOp>()) {
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
+      if (denseAttr.isSplat()) {
+        auto bcastConstAttr = DenseElementsAttr::get(
+            outType, denseAttr.getSplatValue<Attribute>());
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, outType,
+                                                       bcastConstAttr);
+        return success();
+      }
+    }
+  }
+
+  // Determine the broadcast dimensions (stride == 0) and the non-broadcast
+  // shape.
+  SmallVector<int64_t, 4> broadcastDimensions;
+  SmallVector<int64_t, 4> nonBroadcastShape;
+  for (auto [i, stride, shape] : llvm::enumerate(outStrides, outShape)) {
+    if (stride == 0) {
+      broadcastDimensions.push_back(i);
+    } else {
+      nonBroadcastShape.push_back(shape);
+    }
+  }
+
+  // If no dimensions need broadcasting, just reshape to match the output shape.
+  if (broadcastDimensions.empty()) {
+    Value result = reshapeValue(rewriter, adaptor.getInput(), outShape);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  // Reshape the input to match the non-broadcast dimensions of the output.
+  Value input = reshapeValue(rewriter, adaptor.getInput(), nonBroadcastShape);
+
+  auto init = tensor::EmptyOp::create(rewriter, loc, outShape, elemType);
+  auto result = linalg::BroadcastOp::create(rewriter, loc, input, init,
+                                            broadcastDimensions);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
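In the @mbcast_add test below, the output strides 0x1x0x0 mark dimensions 0, 2, and 3 as broadcast, leaving <64> as the non-broadcast shape, so the lowering comes out roughly as (a sketch; %collapsed names the 1x64x1x1 input reshaped to the non-broadcast shape):

```mlir
// Fill dimensions 0, 2, and 3 of the output from the rank-1 input.
%init = tensor.empty() : tensor<1x64x112x112xf32>
%bcast = linalg.broadcast ins(%collapsed : tensor<64xf32>)
    outs(%init : tensor<1x64x112x112xf32>) dimensions = [0, 2, 3]
```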
+//===----------------------------------------------------------------------===//
+// Misc. ops
+//===----------------------------------------------------------------------===//
+namespace {
+struct LiteralConverter final
+    : public OpConversionPattern<migraphx::LiteralOp> {
+  using OpConversionPattern<migraphx::LiteralOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+LiteralConverter::matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) const {
+  migraphx::MIXRShapedType type = op.getResult().getType();
+  RankedTensorType newType =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(type));
+  if (!newType) {
+    return op.emitError("expected RankedTensorType as output");
+  }
+
+  ElementsAttr value = op.getValue();
+  if (value.getType() != newType) {
+    if (value.isSplat()) {
+      // Get the original splat value (for example, an SI8 value).
+      Attribute splatValue = value.getSplatValue<Attribute>();
+
+      // Reinterpret the splat value under the new type (for example,
+      // SI8 -> I8), preserving the bytes.
+      Attribute newSplatValue;
+      if (auto intAttr = dyn_cast<IntegerAttr>(splatValue))
+        newSplatValue =
+            IntegerAttr::get(newType.getElementType(), intAttr.getValue());
+      else if (auto floatAttr = dyn_cast<FloatAttr>(splatValue))
+        newSplatValue =
+            FloatAttr::get(newType.getElementType(), floatAttr.getValue());
+      else if (auto boolAttr = dyn_cast<BoolAttr>(splatValue))
+        // Convert BoolAttr into IntegerAttr so we don't run target
+        // materialization for type conversion. Match the result type of
+        // the TypeConverter.
+        newSplatValue =
+            IntegerAttr::get(newType.getElementType(), boolAttr.getValue());
+      else
+        return failure();
+
+      // Create the new SplatElementsAttr (for example, of I8 type) with the
+      // value bytes preserved.
+      value = SplatElementsAttr::get(newType, newSplatValue);
+    } else {
+      // For non-splat attributes, we need to convert each element to the new
+      // type.
+      SmallVector<Attribute> convertedElements;
+      convertedElements.reserve(value.getNumElements());
+
+      for (auto it : value.getValues<Attribute>()) {
+        Attribute convertedElement;
+        if (auto intAttr = dyn_cast<IntegerAttr>(it))
+          convertedElement =
+              IntegerAttr::get(newType.getElementType(), intAttr.getValue());
+        else if (auto floatAttr = dyn_cast<FloatAttr>(it))
+          convertedElement =
+              FloatAttr::get(newType.getElementType(), floatAttr.getValue());
+        else if (auto boolAttr = dyn_cast<BoolAttr>(it))
+          // Convert BoolAttr into IntegerAttr so we don't run target
+          // materialization for type conversion. Match the result type of
+          // the TypeConverter.
+          convertedElement =
+              IntegerAttr::get(newType.getElementType(), boolAttr.getValue());
+        else
+          return failure();
+
+        convertedElements.push_back(convertedElement);
+      }
+
+      value = DenseElementsAttr::get(newType, convertedElements);
+    }
+  }
+
+  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newType, value);
+  return success();
+}
+
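The non-splat path covers signedness-only changes such as si32 -> i32 by rebuilding the attribute element by element on the signless type; the @literal_dense_si32 test below exercises exactly this. In sketch form:

```mlir
// Before: a signed-integer MIXR literal.
%0 = migraphx.literal (dense<[[0, 1], [2, 3]]> : tensor<2x2xsi32>) : <2x2xsi32, 2x1>
// After: the same values as a signless arith.constant.
%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
```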
 //===----------------------------------------------------------------------===//
 // populateMIGraphXToLinalg* method
 //===----------------------------------------------------------------------===//
@@ -396,7 +643,9 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
       ElementwiseConverter<migraphx::SqrtOp, linalg::SqrtOp>,
       ElementwiseConverter<migraphx::TanhOp, linalg::TanhOp>,
       ElementwiseConverter<migraphx::RecipOp, linalg::ReciprocalOp>,
-      ReluConverter, ClipConverter>(converter, patterns.getContext());
+      ReluConverter, ClipConverter, BroadcastConverter,
+      MultiBroadcastConverter, LiteralConverter>(converter,
+                                                 patterns.getContext());
 }
 
 void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(

mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-not-implemented.mlir

Lines changed: 0 additions & 12 deletions
@@ -109,18 +109,6 @@ func.func @func_slice(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.sh
   func.return
 }
 
-func.func @func_broadcast(%arg0: !migraphx.shaped<1x1xf32, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
-  // expected-error @+1{{failed to legalize operation 'migraphx.broadcast'}}
-  migraphx.broadcast %arg0 {axis = 0 : i64, out_lens = [1, 1]}: <1x1xf32, 1x1> -> <1x1xf32, 1x1>
-  func.return
-}
-
-func.func @func_multibroadcast(%arg0: !migraphx.shaped<1x1xi8, 1x1>, %arg1: !migraphx.shaped<1x1xf32, 1x1>) {
-  // expected-error @+1{{failed to legalize operation 'migraphx.multibroadcast'}}
-  migraphx.multibroadcast %arg0 {out_lens = [1, 1]}: <1x1xi8, 1x1> -> <1x1xi8, 1x1>
-  func.return
-}
-
 func.func @func_quant_dot(%arg0: !migraphx.shaped<1x1xf8E4M3FN, 1x1>, %arg1: !migraphx.shaped<1x1xf8E4M3FN, 1x1>) {
   // expected-error @+1{{failed to legalize operation 'migraphx.quant_dot'}}
   migraphx.quant_dot %arg0, %arg1: <1x1xf8E4M3FN, 1x1>, <1x1xf8E4M3FN, 1x1> -> <1x1xf32, 1x1>

mlir/test/Conversion/MIGraphXToLinalg/mixr-to-linalg-ops.mlir

Lines changed: 88 additions & 0 deletions
@@ -151,3 +151,91 @@ func.func @clip_i32(%arg0: !migraphx.shaped<64x64xi32, 64x1>, %arg1: !migraphx.s
   %0 = migraphx.clip %arg0, %arg1, %arg2 : <64x64xi32, 64x1>, <64x64xi32, 64x1>, <64x64xi32, 64x1> -> <64x64xi32, 64x1>
   return %0 : !migraphx.shaped<64x64xi32, 64x1>
 }
+
+// Literal/Broadcasting tests
+
+// CHECK-LABEL: @matmul_broadcast_op(
+// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}}, %[[arg2:.*]]: tensor{{.*}})
+// CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg0]] {{.*}} into tensor<64x64x2304xf16>
+// CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg1]] {{.*}} into tensor<64x64x768xf16>
+// CHECK-DAG: %[[expanded_1:.*]] = tensor.expand_shape %[[arg2]] {{.*}} into tensor<1x768x2304xf16>
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[expanded_1]] {{.*}} into tensor<1769472xf16>
+// CHECK-DAG: %[[expanded_2:.*]] = tensor.expand_shape %[[collapsed]] {{.*}} into tensor<768x2304xf16>
+// CHECK-DAG: %[[broadcasted:.*]] = linalg.broadcast ins(%[[expanded_2]] : tensor<768x2304xf16>) outs({{.*}} : tensor<64x768x2304xf16>) dimensions = [0]
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<64x64x2304xf16>
+// CHECK-DAG: %[[matmul:.*]] = linalg.batch_matmul ins(%[[expanded_0]], %[[broadcasted]] : {{.*}}) outs(%[[cst]] : {{.*}})
+// CHECK-DAG: %[[add:.*]] = linalg.add ins(%[[matmul]], %[[expanded]] : {{.*}}) outs({{.*}})
+// CHECK-DAG: %[[collapsed_3:.*]] = tensor.collapse_shape %[[add]]
+// CHECK-DAG: return %[[collapsed_3]]
+func.func @matmul_broadcast_op(%arg0: !migraphx.shaped<64x64x2304xf16, 147456x2304x1>, %arg1: !migraphx.shaped<64x64x768xf16, 49152x768x1>, %arg2: !migraphx.shaped<1x768x2304xf16, 1769472x2304x1>) -> !migraphx.shaped<64x64x2304xf16, 147456x2304x1> {
+  %0 = migraphx.broadcast %arg2 {axis = 0, out_lens = [64, 768, 2304]} : <1x768x2304xf16, 1769472x2304x1> -> <64x768x2304xf16, 0x2304x1>
+  %1 = migraphx.dot %arg1, %0 : <64x64x768xf16, 49152x768x1>, <64x768x2304xf16, 0x2304x1> -> <64x64x2304xf16, 147456x2304x1>
+  %2 = migraphx.add %1, %arg0 : <64x64x2304xf16, 147456x2304x1>, <64x64x2304xf16, 147456x2304x1> -> <64x64x2304xf16, 147456x2304x1>
+  return %2 : !migraphx.shaped<64x64x2304xf16, 147456x2304x1>
+}
+
+// CHECK-LABEL: @mbcast_add(
+// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}}, %[[arg1:.*]]: tensor{{.*}})
+// CHECK-DAG: %[[expanded:.*]] = tensor.expand_shape %[[arg0]] {{.*}} into tensor<1x64x112x112xf32>
+// CHECK-DAG: %[[expanded_0:.*]] = tensor.expand_shape %[[arg1]] {{.*}} into tensor<1x64x1x1xf32>
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[expanded_0]] {{.*}} into tensor<64xf32>
+// CHECK-DAG: %[[broadcasted:.*]] = linalg.broadcast ins(%[[collapsed]] : tensor<64xf32>) outs({{.*}} : tensor<1x64x112x112xf32>) dimensions = [0, 2, 3]
+// CHECK-DAG: %[[add:.*]] = linalg.add ins(%[[expanded]], %[[broadcasted]] : {{.*}}) outs({{.*}})
+// CHECK-DAG: %[[collapsed_2:.*]] = tensor.collapse_shape %[[add]]
+// CHECK-DAG: return %[[collapsed_2]]
+func.func @mbcast_add(
+    %arg0: !migraphx.shaped<1x64x112x112xf32, 802816x12544x112x1>,
+    %arg1: !migraphx.shaped<1x64x1x1xf32, 64x1x1x1>
+) -> !migraphx.shaped<1x64x112x112xf32, 802816x12544x112x1> {
+  %0 = migraphx.multibroadcast %arg1 {out_lens = [1, 64, 112, 112]} : <1x64x1x1xf32, 64x1x1x1> -> <1x64x112x112xf32, 0x1x0x0>
+  %1 = migraphx.add %arg0, %0 : <1x64x112x112xf32, 802816x12544x112x1>, <1x64x112x112xf32, 0x1x0x0> -> <1x64x112x112xf32, 802816x12544x112x1>
+  return %1 : !migraphx.shaped<1x64x112x112xf32, 802816x12544x112x1>
+}
+
+// CHECK-LABEL: @literal_splat_f32()
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<4x3xf32>
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[cst]]
+// CHECK-DAG: return %[[collapsed]]
+func.func @literal_splat_f32() -> !migraphx.shaped<4x3xf32, 3x1> {
+  %0 = migraphx.literal (dense<0.0> : tensor<4x3xf32>) : <4x3xf32, 3x1>
+  return %0 : !migraphx.shaped<4x3xf32, 3x1>
+}
+
+// CHECK-LABEL: @literal(
+// CHECK-SAME: %[[arg0:.*]]: tensor{{.*}})
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<16xf32>
+// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[cst]]
+// CHECK-DAG: return %[[collapsed]]
+func.func @literal(%arg0: !migraphx.shaped<16xf32, 1>) -> !migraphx.shaped<16xf32, 1> {
+  %cst = migraphx.literal (dense<1.0> : tensor<16xf32>) : <16xf32, 1>
+  return %cst : !migraphx.shaped<16xf32, 1>
+}
+
+// CHECK-LABEL: @literal_dense_si32
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<{{.*}}> : tensor<2x2xi32>
+func.func @literal_dense_si32() -> !migraphx.shaped<2x2xsi32, 2x1> {
+  %0 = migraphx.literal (dense<[[0, 1], [2, 3]]> : tensor<2x2xsi32>) : <2x2xsi32, 2x1>
+  return %0 : !migraphx.shaped<2x2xsi32, 2x1>
+}
+
+// CHECK-LABEL: @scalar_multibroadcast_test
+// CHECK-DAG: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : tensor<2x2xf32>
+// CHECK-DAG: %[[zero:.*]] = tensor.empty
+// CHECK-DAG: %[[one:.*]] = linalg.add ins(%[[cst_0]], %[[cst_0]] : {{.*}}) outs(%[[zero]] : {{.*}})
+func.func @scalar_multibroadcast_test() -> !migraphx.shaped<2x2xf32, 2x1> {
+  %test = migraphx.literal (dense<0.0> : tensor<f32>) : <f32>
+  %result = migraphx.multibroadcast %test {out_dyn_dims = [], out_lens = [2, 2]} : <f32> -> <2x2xf32, 0x0>
+  %sum = migraphx.add %result, %result : <2x2xf32, 0x0>, <2x2xf32, 0x0> -> <2x2xf32, 2x1>
+  return %sum : !migraphx.shaped<2x2xf32, 2x1>
+}
+
+// CHECK-LABEL: @scalar_broadcast_test
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-DAG: %[[zero:.*]] = tensor.empty()
+// CHECK-DAG: %[[broadcasted:.*]] = linalg.broadcast ins(%[[cst]] : {{.*}}) outs(%[[zero]] : {{.*}}) dimensions = [0, 1]
+func.func @scalar_broadcast_test() -> !migraphx.shaped<2x2xf32, 2x1> {
+  %test = migraphx.literal (dense<0.0> : tensor<f32>) : <f32>
+  %result = migraphx.broadcast %test {axis = 1 : i64, out_lens = [2, 2]} : <f32> -> <2x2xf32, 0x0>
+  %sum = migraphx.add %result, %result : <2x2xf32, 0x0>, <2x2xf32, 0x0> -> <2x2xf32, 2x1>
+  return %sum : !migraphx.shaped<2x2xf32, 2x1>
+}
