@@ -376,6 +376,253 @@ ClipConverter::matchAndRewrite(migraphx::ClipOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Tensor views and shape manipulation
+//===----------------------------------------------------------------------===//
+namespace {
+struct BroadcastConverter final
+    : public OpConversionPattern<migraphx::BroadcastOp> {
+  using OpConversionPattern<migraphx::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+struct MultiBroadcastConverter final
+    : public OpConversionPattern<migraphx::MultiBroadcastOp> {
+  using OpConversionPattern<migraphx::MultiBroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Reshape the input Value into a new RankedTensorType with newShape.
+/// The input must have a RankedTensorType.
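+///
+/// For example (illustrative IR, assuming an f32 element type), reshaping a
+/// `tensor<1x4x1xf32>` to shape [2, 2] collapses everything into one
+/// dimension and then expands it:
+///   %c = tensor.collapse_shape %in [[0, 1, 2]]
+///       : tensor<1x4x1xf32> into tensor<4xf32>
+///   %e = tensor.expand_shape %c [[0, 1]] output_shape [2, 2]
+///       : tensor<4xf32> into tensor<2x2xf32>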
+static Value reshapeValue(ConversionPatternRewriter &rewriter, Value input,
+                          ArrayRef<int64_t> newShape) {
+  // Although there is a tensor.reshape op, we use tensor.collapse_shape
+  // and tensor.expand_shape, since the rock-view-to-transform pass doesn't
+  // support tensor.reshape.
+  RankedTensorType currentType = cast<RankedTensorType>(input.getType());
+  Location loc = input.getLoc();
+  int64_t inputRank = currentType.getRank();
+  int64_t outputRank = static_cast<int64_t>(newShape.size());
+
+  if (currentType.getShape() == newShape) {
+    return input;
+  }
+
+  SmallVector<ReassociationIndices> collapseReassociation(1);
+  SmallVector<ReassociationIndices> expandReassociation(1);
+  collapseReassociation[0].resize(inputRank);
+  expandReassociation[0].resize(outputRank);
+  std::iota(collapseReassociation[0].begin(), collapseReassociation[0].end(),
+            0);
+  std::iota(expandReassociation[0].begin(), expandReassociation[0].end(), 0);
+  input = tensor::CollapseShapeOp::create(rewriter, loc, input,
+                                          collapseReassociation);
+  if (cast<RankedTensorType>(input.getType()).getShape() == newShape) {
+    return input;
+  }
+  RankedTensorType resultType =
+      RankedTensorType::get(newShape, currentType.getElementType());
+  input = tensor::ExpandShapeOp::create(rewriter, loc, resultType, input,
+                                        expandReassociation);
+  return input;
+}
+
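+/// Lower migraphx.broadcast to linalg.broadcast. The broadcast dimensions are
+/// [0, axis), every size-1 input dimension shifted by axis, and
+/// [axis + inputRank, outputRank).
+///
+/// For example (illustrative IR, with assumed f32 element types): with
+/// axis = 1, an input of type `tensor<4xf32>`, and an output of shape 2x4x3,
+/// this emits roughly:
+///   %init = tensor.empty() : tensor<2x4x3xf32>
+///   linalg.broadcast ins(%in : tensor<4xf32>)
+///       outs(%init : tensor<2x4x3xf32>) dimensions = [0, 2]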
+LogicalResult
+BroadcastConverter::matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  migraphx::MIXRShapedType input = op.getInput().getType();
+  migraphx::MIXRShapedType output = op.getOutput().getType();
+
+  RankedTensorType outputType =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(output));
+  if (!outputType) {
+    return op.emitError("cannot convert output type to ranked tensor type");
+  }
+
+  uint64_t axis = op.getAxis();
+  uint64_t outputRank = output.getRank();
+
+  uint64_t inputRank = input.getRank();
+  SmallVector<int64_t, 4> dimensionAttr;
+  llvm::transform(llvm::seq<int64_t>(0, axis),
+                  std::back_inserter(dimensionAttr),
+                  [](int64_t val) { return val; });
+  for (auto [index, dim] : llvm::enumerate(input.getShape())) {
+    // Size-1 input dimensions are broadcast as well.
+    if (dim == 1) {
+      dimensionAttr.push_back(index + axis);
+    }
+  }
+  llvm::transform(llvm::seq<int64_t>(axis + inputRank, outputRank),
+                  std::back_inserter(dimensionAttr),
+                  [](int64_t val) { return val; });
+
+  // Remove the size-1 dimensions from the input, since each of them may be
+  // broadcast to a different extent.
+  auto reshaped =
+      reshapeValue(rewriter, adaptor.getInput(),
+                   llvm::filter_to_vector(
+                       input.getShape(), [](int64_t val) { return val != 1; }));
+  auto init = tensor::EmptyOp::create(rewriter, loc, outputType.getShape(),
+                                      outputType.getElementType());
+  auto result =
+      linalg::BroadcastOp::create(rewriter, loc, reshaped, init, dimensionAttr);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
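+/// Lower migraphx.multibroadcast to linalg.broadcast, using the output
+/// strides to identify the broadcast dimensions: a stride of 0 marks a
+/// dimension that the input is broadcast along.
+///
+/// For example (illustrative IR, with an assumed f32 element type): an output
+/// of shape 2x3x4 with strides [0, 4, 1] broadcasts along dimension 0, so the
+/// input is reshaped to `tensor<3x4xf32>` and lowered roughly as:
+///   %init = tensor.empty() : tensor<2x3x4xf32>
+///   linalg.broadcast ins(%in : tensor<3x4xf32>)
+///       outs(%init : tensor<2x3x4xf32>) dimensions = [0]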
+LogicalResult MultiBroadcastConverter::matchAndRewrite(
+    migraphx::MultiBroadcastOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  migraphx::MIXRShapedType outMIXRType = op.getOutput().getType();
+  RankedTensorType outType =
+      cast<RankedTensorType>(getTypeConverter()->convertType(outMIXRType));
+  ArrayRef<int64_t> outShape = outType.getShape();
+  ArrayRef<int64_t> outStrides = outMIXRType.getStrides();
+  uint32_t inRank =
+      cast<RankedTensorType>(adaptor.getInput().getType()).getRank();
+  uint32_t outRank = outType.getRank();
+  Type elemType = outType.getElementType();
+
+  assert(outRank >= inRank && "MultiBroadcastOp shouldn't reduce rank. This "
+                              "should be an invariant of this operation");
+
+  // If the input is a splat constant, broadcast it trivially.
+  if (auto constOp = adaptor.getInput().getDefiningOp<arith::ConstantOp>()) {
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
+      if (denseAttr.isSplat()) {
+        auto bcastConstAttr = DenseElementsAttr::get(
+            outType, denseAttr.getSplatValue<Attribute>());
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, outType,
+                                                       bcastConstAttr);
+        return success();
+      }
+    }
+  }
+
+  // Determine the broadcast dimensions (stride == 0) and the non-broadcast
+  // shape.
+  SmallVector<int64_t, 4> broadcastDimensions;
+  SmallVector<int64_t, 4> nonBroadcastShape;
+  for (auto [i, stride, shape] : llvm::enumerate(outStrides, outShape)) {
+    if (stride == 0) {
+      broadcastDimensions.push_back(i);
+    } else {
+      nonBroadcastShape.push_back(shape);
+    }
+  }
+
+  // If no dimensions need broadcasting, just reshape to match the output
+  // shape.
+  if (broadcastDimensions.empty()) {
+    Value result = reshapeValue(rewriter, adaptor.getInput(), outShape);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  // Reshape the input to match the non-broadcast dimensions of the output.
+  Value input = reshapeValue(rewriter, adaptor.getInput(), nonBroadcastShape);
+
+  auto init = tensor::EmptyOp::create(rewriter, loc, outShape, elemType);
+  auto result = linalg::BroadcastOp::create(rewriter, loc, input, init,
+                                            broadcastDimensions);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Misc. ops
+//===----------------------------------------------------------------------===//
+namespace {
+struct LiteralConverter final
+    : public OpConversionPattern<migraphx::LiteralOp> {
+  using OpConversionPattern<migraphx::LiteralOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
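+/// Lower migraphx.literal to arith.constant, rebuilding the ElementsAttr on
+/// the converted tensor type when the two differ (for example, a signed si8
+/// literal becomes a signless i8 constant).
+///
+/// For example (illustrative IR), a splat si8 literal of value 1 over a
+/// 4-element shape lowers to:
+///   %cst = arith.constant dense<1> : tensor<4xi8>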
+LogicalResult
+LiteralConverter::matchAndRewrite(migraphx::LiteralOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) const {
+  migraphx::MIXRShapedType type = op.getResult().getType();
+  RankedTensorType newType =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(type));
+  if (!newType) {
+    return op.emitError("expected RankedTensorType as output");
+  }
+
+  ElementsAttr value = op.getValue();
+  if (value.getType() != newType) {
+    if (value.isSplat()) {
+      // Get the original splat value (for example, an SI8 value).
+      Attribute splatValue = value.getSplatValue<Attribute>();
+
+      // Reinterpret the splat value under the new type (for example,
+      // SI8 -> I8), preserving its bytes.
+      Attribute newSplatValue;
+      if (auto intAttr = dyn_cast<IntegerAttr>(splatValue))
+        newSplatValue =
+            IntegerAttr::get(newType.getElementType(), intAttr.getValue());
+      else if (auto floatAttr = dyn_cast<FloatAttr>(splatValue))
+        newSplatValue =
+            FloatAttr::get(newType.getElementType(), floatAttr.getValue());
+      else if (auto boolAttr = dyn_cast<BoolAttr>(splatValue))
+        // Convert BoolAttr into IntegerAttr so we don't run target
+        // materialization for the type conversion; match the result type of
+        // the TypeConverter.
+        newSplatValue =
+            IntegerAttr::get(newType.getElementType(), boolAttr.getValue());
+      else
+        return failure();
+
+      // Create the new SplatElementsAttr (for example, of I8 type) with the
+      // value bytes preserved.
+      value = SplatElementsAttr::get(newType, newSplatValue);
+    } else {
+      // For non-splat attributes, we need to convert each element to the new
+      // type.
+      SmallVector<Attribute> convertedElements;
+      convertedElements.reserve(value.getNumElements());
+
+      for (auto it : value.getValues<Attribute>()) {
+        Attribute convertedElement;
+        if (auto intAttr = dyn_cast<IntegerAttr>(it))
+          convertedElement =
+              IntegerAttr::get(newType.getElementType(), intAttr.getValue());
+        else if (auto floatAttr = dyn_cast<FloatAttr>(it))
+          convertedElement =
+              FloatAttr::get(newType.getElementType(), floatAttr.getValue());
+        else if (auto boolAttr = dyn_cast<BoolAttr>(it))
+          // Convert BoolAttr into IntegerAttr so we don't run target
+          // materialization for the type conversion; match the result type of
+          // the TypeConverter.
+          convertedElement =
+              IntegerAttr::get(newType.getElementType(), boolAttr.getValue());
+        else
+          return failure();
+
+        convertedElements.push_back(convertedElement);
+      }
+
+      value = DenseElementsAttr::get(newType, convertedElements);
+    }
+  }
+
+  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newType, value);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // populateMIGraphXToLinalg* method
 //===----------------------------------------------------------------------===//
@@ -396,7 +643,9 @@ void mlir::migraphx::populateMIGraphXToLinalgConversionPatterns(
                ElementwiseConverter<migraphx::SqrtOp, linalg::SqrtOp>,
                ElementwiseConverter<migraphx::TanhOp, linalg::TanhOp>,
                ElementwiseConverter<migraphx::RecipOp, linalg::ReciprocalOp>,
-               ReluConverter, ClipConverter>(converter, patterns.getContext());
+               ReluConverter, ClipConverter, BroadcastConverter,
+               MultiBroadcastConverter, LiteralConverter>(converter,
+                                                          patterns.getContext());
 }
 
 void mlir::migraphx::populateMIGraphXFuncBoundaryToLinalgConversionPatterns(