diff --git a/shardy/dialect/mpmd/ir/utils.cc b/shardy/dialect/mpmd/ir/utils.cc index 004653765..b9478c1c4 100644 --- a/shardy/dialect/mpmd/ir/utils.cc +++ b/shardy/dialect/mpmd/ir/utils.cc @@ -379,7 +379,7 @@ SmallVector GetMpmdDataflowEdges(FuncOp func_op) { FragmentOp WrapOpWithFragment( Operation* op, StringRef mesh_name, RewriterBase& rewriter, - std::function should_replace_use) { + StringRef inferred_by, std::function should_replace_use) { // We set the insertion point right before `op` so assigns of operands will be // in the right place regardless of previous insertion point. rewriter.setInsertionPoint(op); @@ -437,6 +437,10 @@ FragmentOp WrapOpWithFragment( return block_builder.clone(*op, mapping)->getResults(); }); + fragment_op->setAttr( + kInferredByAttr, + rewriter.getArrayAttr({rewriter.getStringAttr(inferred_by)})); + // Unassign all fragment results and replace all uses of `op` with the // corresponding unassign op for which `should_replace_use` returns true. for (auto [original_result, fragment_result] : diff --git a/shardy/dialect/mpmd/ir/utils.h b/shardy/dialect/mpmd/ir/utils.h index 9e3cdec97..ec11ad61f 100644 --- a/shardy/dialect/mpmd/ir/utils.h +++ b/shardy/dialect/mpmd/ir/utils.h @@ -40,18 +40,16 @@ limitations under the License. namespace mlir::mpmd { -// Globsl sdy mesh name. -constexpr StringRef kGlobalMeshName = "mesh"; - +// Global sdy mesh name. +inline constexpr StringRef kGlobalMeshName = "mesh"; // The function attribute that holds the SPMD mesh. -constexpr StringRef kMeshShapeAttr = "mesh_shape"; +inline constexpr StringRef kMeshShapeAttr = "mesh_shape"; // The function attribute that holds the MPMD topology. -constexpr StringRef kTopologyAttr = "topology"; - +inline constexpr StringRef kTopologyAttr = "topology"; // The suffix of the mesh name for a CPU mesh. // LINT.IfChange -constexpr StringRef kCpuMeshSuffix = "/cpu"; +inline constexpr StringRef kCpuMeshSuffix = "/cpu"; // LINT.ThenChange( // https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py // ) @@ -76,6 +74,9 @@ inline constexpr StringRef kRematAttributeName = "remat"; inline constexpr StringRef kJaxResultInfoAttr = "jax.result_info"; +// The attribute that holds the list of pass names that inferred a fragment. +inline constexpr StringRef kInferredByAttr = "mpmd.inferred_by"; + template std::string StrCat(Args&&... args) { std::string result; @@ -261,6 +262,7 @@ SmallVector GetMpmdDataflowEdges(func::FuncOp func_op); // `should_replace_use` returns true. FragmentOp WrapOpWithFragment( Operation* op, StringRef mesh_name, RewriterBase& rewriter, + StringRef inferred_by, std::function should_replace_use = [](OpOperand&) { return true; }); diff --git a/shardy/dialect/mpmd/transforms/common/BUILD b/shardy/dialect/mpmd/transforms/common/BUILD index 9ed69193c..bbc8f1836 100644 --- a/shardy/dialect/mpmd/transforms/common/BUILD +++ b/shardy/dialect/mpmd/transforms/common/BUILD @@ -44,6 +44,7 @@ cc_library( "split_bwd_fragments.cc", "uniquify_function_inputs_outputs.cc", "unroll_for_loops.cc", + "wrap_block_arg_returns.cc", ], hdrs = [ "merge_fragments.h", diff --git a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc index 3cd08a4ff..d162e9c0f 100644 --- a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc +++ b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -182,6 +183,31 @@ std::optional MergeCallCounters(FragmentOp producer_op, return std::nullopt; } +std::optional MergeInferredByAttributes(FragmentOp producer_op, + FragmentOp consumer_op) { + ArrayAttr producer_inferred_by = + producer_op->getAttrOfType(kInferredByAttr); + ArrayAttr consumer_inferred_by = + consumer_op->getAttrOfType(kInferredByAttr); + + if (!producer_inferred_by && !consumer_inferred_by) { + return std::nullopt; + } + + llvm::SetVector combined_inferred_by; + if (producer_inferred_by) { + combined_inferred_by.insert(producer_inferred_by.begin(), + producer_inferred_by.end()); + } + if (consumer_inferred_by) { + combined_inferred_by.insert(consumer_inferred_by.begin(), + consumer_inferred_by.end()); + } + + IRRewriter rewriter(producer_op.getContext()); + return rewriter.getArrayAttr(combined_inferred_by.takeVector()); +} + // Returns a list of attributes that must be preserved in the merged fragment. // Note: origins are preserved by default and require no extra work. SmallVector> MergedAttributes( @@ -194,6 +220,12 @@ SmallVector> MergedAttributes( attributes.emplace_back(kCallCounterAttrName, rewriter.getUI32IntegerAttr(*merged_call_count)); } + + if (std::optional merged_inferred_by = + MergeInferredByAttributes(producer_op, consumer_op)) { + attributes.emplace_back(kInferredByAttr, *merged_inferred_by); + } + return attributes; } diff --git a/shardy/dialect/mpmd/transforms/common/passes.td b/shardy/dialect/mpmd/transforms/common/passes.td index 0b7dabc4b..700e7bd95 100644 --- a/shardy/dialect/mpmd/transforms/common/passes.td +++ b/shardy/dialect/mpmd/transforms/common/passes.td @@ -422,6 +422,22 @@ def RemoveTransferCyclesPass : }]; } +def WrapBlockArgReturnsPass : + PassBase<"mpmd-wrap-block-arg-returns", "DistributedFunctionPass"> { + let summary = "Wraps block arguments that are directly returned by the " + "function in identity fragments."; + let description = [{ + If a function directly returns a block argument, this pass creates an + identity fragment for that block argument, ensuring that sharding + constraints can be applied to an op result rather than a block argument. + + This guarantees that values are passed by value to the function, not by + reference. + }]; + + let dependentDialects = ["mlir::mpmd::MpmdDialect"]; +} + // TODO(jupvfranco): we should create these copies using ifrt_ir reshard. We are // not ready for that yet as we need better support for donation in the // presence of reshards. diff --git a/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir b/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir index 789cdbb02..ce7b13c29 100644 --- a/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir +++ b/shardy/dialect/mpmd/transforms/common/test/uniquify_function_inputs_outputs_with_reshard.mlir @@ -28,7 +28,7 @@ func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_ten } { // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment // CHECK: %[[F2:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) (%arg1: tensor<4xf32>) { + // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) { // CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32> // CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1 %0 = mpmd.fragment (%arg0) (%arg1: tensor<4xf32>) { @@ -49,7 +49,7 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1: // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment // CHECK: %[[F2:.*]] = mpmd.fragment // CHECK: %[[F3:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:5 = mpmd.fragment (%[[F1]], %[[F3]]) (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>) + // CHECK: %[[UF:.*]]:5 = mpmd.fragment (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>) // CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]] // CHECK-NEXT: } // CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4 @@ -70,8 +70,8 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te ) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes { "topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> } { - // CHECK: %[[UF1:.*]]:4 = mpmd.fragment - // CHECK: %[[UF2:.*]]:2 = mpmd.fragment + // CHECK: %[[UF1:.*]]:4 = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["uniquify"]} + // CHECK: %[[UF2:.*]]:2 = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["uniquify"]} // CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3 %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { mpmd.return %arg2 : tensor<4xf32> @@ -97,7 +97,7 @@ func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_ten } { // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment // CHECK: %[[F2:.*]] = mpmd.fragment - // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) (%arg1: tensor<4xf32>) { + // CHECK: %[[UF:.*]]:2 = mpmd.fragment (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) { // CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32> // CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1 %0 = mpmd.fragment (%arg0) (%arg1: tensor<4xf32>) { @@ -122,7 +122,7 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor) // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment (%arg0) (%arg1: tensor<4xui32>) { // CHECK-NEXT: return %arg1 // CHECK-NEXT: } - // CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment (%arg0) (%arg1: tensor<4xui32>) { + // CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) { // CHECK-NEXT: return %arg1, %arg1 // CHECK-NEXT: } // CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1 @@ -140,9 +140,8 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor) func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>} { - // CHECK-NEXT: %[[F:.*]] = mpmd.fragment (%arg0) (%arg1: tensor<4xui32>) { - // CHECK-NEXT: return %arg1 - // CHECK-NEXT: } - // CHECK-NEXT: return %[[F]] + // Block-arg wrapping is handled by a separate pass. The uniquify pass is a + // no-op here since %arg0 appears only once. + // CHECK-NEXT: return %arg0 func.return %arg0 : !mesh_tensor } diff --git a/shardy/dialect/mpmd/transforms/common/test/wrap_block_arg_returns.mlir b/shardy/dialect/mpmd/transforms/common/test/wrap_block_arg_returns.mlir new file mode 100644 index 000000000..96e04aa80 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/wrap_block_arg_returns.mlir @@ -0,0 +1,61 @@ +// RUN: mpmd_opt %s -mpmd-wrap-block-arg-returns -split-input-file 2>&1 | FileCheck %s + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>> +!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>> + +// CHECK-LABEL: func @no_work_needed +func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_2_tensor) attributes { + "topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> +} { + // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment + // CHECK: %[[F2:.*]] = mpmd.fragment + // CHECK: return %[[F1]], %[[F2]] + %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %1 : tensor<4xf32> + } : (!mesh_1_tensor) -> !mesh_1_tensor + %1 = mpmd.fragment (%arg1) (%arg2: tensor<4xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %1 : tensor<4xf32> + } : (!mesh_2_tensor) -> !mesh_2_tensor + return %0, %1 : !mesh_1_tensor, !mesh_2_tensor +} + +// ----- + +!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>> + +module { + +// CHECK-LABEL: func @identity_function +func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor + attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>} +{ + // CHECK-NEXT: %[[F:.*]] = mpmd.fragment (%arg0) {mpmd.inferred_by = ["wrap_block_arg_returns"]} (%arg1: tensor<4xui32>) { + // CHECK-NEXT: return %arg1 + // CHECK-NEXT: } + // CHECK-NEXT: return %[[F]] + func.return %arg0 : !mesh_tensor +} + +} + +// ----- + +!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>> +!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>> + +// CHECK-LABEL: func @mixed_returns +func.func @mixed_returns(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_2_tensor) attributes { + "topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> +} { + // CHECK-NEXT: %[[F1:.*]] = mpmd.fragment + // CHECK: %[[WRAP1:.*]] = mpmd.fragment (%arg0) {mpmd.inferred_by = ["wrap_block_arg_returns"]} + // CHECK: %[[WRAP2:.*]] = mpmd.fragment (%arg1) {mpmd.inferred_by = ["wrap_block_arg_returns"]} + // CHECK: return %[[WRAP1]], %[[F1]], %[[WRAP2]] + %0 = mpmd.fragment (%arg0) (%arg2: tensor<4xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4xf32> + mpmd.return %1 : tensor<4xf32> + } : (!mesh_1_tensor) -> !mesh_1_tensor + return %arg0, %0, %arg1 : !mesh_1_tensor, !mesh_1_tensor, !mesh_2_tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc index 5511c7345..2f58a12c6 100644 --- a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc +++ b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -30,7 +31,6 @@ limitations under the License. #include "shardy/dialect/mpmd/ir/utils.h" #include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep #include "shardy/dialect/sdy/ir/dialect.h" -#include "mlir/IR/MLIRContext.h" namespace mlir::mpmd { @@ -44,16 +44,6 @@ using ValueToReturnIndices = llvm::MapVector>; void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, ValueToReturnIndices& value_to_return_indices, OpBuilder& builder) { - // We remove any entries that require no work, in order to avoid too many - // checks. - value_to_return_indices.remove_if([](const auto& it) { - if (it.second.size() == 1) { - Value v = it.first; - return !isa(v); - } - return it.second.empty(); - }); - builder.setInsertionPoint(return_op); SmallVector fragment_operands; fragment_operands.reserve(value_to_return_indices.size()); @@ -65,15 +55,14 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, cast(value.getType())); } - if (fragment_operands.empty()) { - return; - } - auto loc = return_op->getLoc(); auto fragment_op = FragmentOp::create( builder, loc, fragment_return_types, fragment_operands, /*user_origin=*/ArrayAttr::get(builder.getContext(), {}), /*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr()); + fragment_op->setAttr( + kInferredByAttr, + builder.getArrayAttr({builder.getStringAttr("uniquify")})); Block& fragment_block = fragment_op.getRegion().emplaceBlock(); SmallVector returned_values; @@ -88,8 +77,7 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, returned_values.insert( returned_values.end(), return_indices.size(), fragment_block.addArgument( - GetGlobalTensorTypeFromMeshType(value, mesh_attr), - value.getLoc())); + GetGlobalTensorTypeFromMeshType(value, mesh_attr), value.getLoc())); for (int64_t index : return_indices) { return_op->setOperand(index, @@ -150,7 +138,7 @@ class UniquifyFunctionInputOutputsPass using UniquifyFunctionInputsOutputsPassBase:: UniquifyFunctionInputsOutputsPassBase; - private: + protected: void runOnFunc(func::FuncOp func_op) override { if (!IsMpmdFunction(func_op)) { // This is not the main function. Do nothing. @@ -179,6 +167,9 @@ class UniquifyFunctionInputOutputsPass OpBuilder builder(&getContext()); for (auto& [mesh_name, value_to_return_indices] : value_to_return_indices_per_mesh) { + // Only keep entries that are returned more than once (need uniquifying). + value_to_return_indices.remove_if( + [](const auto& it) { return it.second.size() <= 1; }); CreateReturnFragmentForMesh(mesh_name, return_op, value_to_return_indices, builder); } diff --git a/shardy/dialect/mpmd/transforms/common/wrap_block_arg_returns.cc b/shardy/dialect/mpmd/transforms/common/wrap_block_arg_returns.cc new file mode 100644 index 000000000..44182be13 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/wrap_block_arg_returns.cc @@ -0,0 +1,127 @@ +/* Copyright 2025 The MPMD Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "shardy/common/logging.h" +#include "shardy/dialect/mpmd/ir/dialect.h" +#include "shardy/dialect/mpmd/ir/utils.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" + +namespace mlir::mpmd { + +#define GEN_PASS_DEF_WRAPBLOCKARGRETURNSPASS +#include "shardy/dialect/mpmd/transforms/common/passes.h.inc" + +namespace { + +using ValueToReturnIndices = llvm::MapVector>; + +// Creates an identity fragment for block arguments returned by the function +// on a given mesh. Each block argument gets a single fragment input and one +// result per return position it appears in. +void WrapBlockArgsForMesh(StringRef mesh_name, Operation* return_op, + ValueToReturnIndices& value_to_return_indices, + OpBuilder& builder) { + if (value_to_return_indices.empty()) { + return; + } + + if (value_to_return_indices.empty()) { + return; + } + + builder.setInsertionPoint(return_op); + SmallVector fragment_operands; + fragment_operands.reserve(value_to_return_indices.size()); + SmallVector fragment_return_types; + for (const auto& [value, return_indices] : value_to_return_indices) { + fragment_operands.push_back(value); + fragment_return_types.insert(fragment_return_types.end(), + return_indices.size(), + cast(value.getType())); + } + + auto loc = return_op->getLoc(); + auto fragment_op = FragmentOp::create( + builder, loc, fragment_return_types, fragment_operands, + /*user_origin=*/ArrayAttr::get(builder.getContext(), {}), + /*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr()); + fragment_op->setAttr( + kInferredByAttr, + builder.getArrayAttr({builder.getStringAttr("wrap_block_arg_returns")})); + Block& fragment_block = fragment_op.getRegion().emplaceBlock(); + + SmallVector returned_values; + returned_values.reserve(fragment_return_types.size()); + int fragment_result_index = 0; + sdy::MeshAttr mesh_attr = GetMeshOrFail(fragment_op, mesh_name); + for (const auto& [value, return_indices] : value_to_return_indices) { + returned_values.insert( + returned_values.end(), return_indices.size(), + fragment_block.addArgument( + GetGlobalTensorTypeFromMeshType(value, mesh_attr), value.getLoc())); + + for (int64_t index : return_indices) { + return_op->setOperand(index, + fragment_op->getResult(fragment_result_index++)); + } + } + auto block_builder = OpBuilder::atBlockEnd(&fragment_block); + ReturnOp::create(block_builder, loc, returned_values); +} + +class WrapBlockArgReturnsPass + : public impl::WrapBlockArgReturnsPassBase { + using WrapBlockArgReturnsPassBase::WrapBlockArgReturnsPassBase; + + protected: + void runOnFunc(func::FuncOp func_op) override { + if (!IsMpmdFunction(func_op)) { + return; + } + + Operation* return_op = func_op.getBody().front().getTerminator(); + llvm::MapVector + value_to_return_indices_per_mesh; + for (OpOperand& operand : return_op->getOpOperands()) { + if (!isa(operand.get())) continue; + auto mesh_type = dyn_cast(operand.get().getType()); + SDY_CHECK(mesh_type); + StringRef mesh_name = mesh_type.getMeshName(); + value_to_return_indices_per_mesh[mesh_name][operand.get()].push_back( + operand.getOperandNumber()); + } + + OpBuilder builder(&getContext()); + for (auto& [mesh_name, value_to_return_indices] : + value_to_return_indices_per_mesh) { + WrapBlockArgsForMesh(mesh_name, return_op, value_to_return_indices, + builder); + } + } +}; + +} // namespace +} // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc index 6cbf2f8f5..1e354f2e1 100644 --- a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc +++ b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc @@ -84,6 +84,7 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) { // Must be applied after the last -mpmd-fragment-dedup, as it may add // duplicated fragment results and after -canonicalize, as it may add // identity fragments, which would be canonicalized away. + pm.addNestedPass(createWrapBlockArgReturnsPass()); pm.addNestedPass(createUniquifyFunctionInputsOutputsPass()); // The fragments created by the pass above maybe slowdown compilation (more diff --git a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc index 88aafe9af..024d773f3 100644 --- a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc +++ b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -400,7 +401,7 @@ FragmentOp CreateReduceFragment(ArrayRef mesh_tensors, StringRef mesh_name, ReductionType reduction_type, RewriterBase& rewriter) { - return FragmentOp::createMeshFragmentWithGlobalBody( + FragmentOp fragment_op = FragmentOp::createMeshFragmentWithGlobalBody( mesh_tensors.front().getLoc(), /*user_origin=*/{}, mesh_name, mesh_tensors, mesh_tensors.front().getType(), rewriter, [reduction_type](ArrayRef args, OpBuilder& block_builder) { @@ -413,6 +414,10 @@ FragmentOp CreateReduceFragment(ArrayRef mesh_tensors, } return SmallVector({accumulator}); }); + fragment_op->setAttr(kInferredByAttr, + rewriter.getArrayAttr({rewriter.getStringAttr( + "infer_mesh_convert_reduce_ops")})); + return fragment_op; } // This pattern lowers mpmd.reduce to reductions and transfers. @@ -1422,8 +1427,9 @@ class InferMeshAssignMeshForFuncLeavesPass if (MeshesWithOrigins src_set = GetSrcSet(op)) { if (src_set.empty()) { if (!inferTransfers) { - op->emitError("src_set must not be empty for this op. Try setting " - "`mpmd_infer_transfers` in the partitioning options."); + op->emitError( + "src_set must not be empty for this op. Try setting " + "`mpmd_infer_transfers` in the partitioning options."); // In this case, we have to stop here, or otherwise we would crash // below. return signalPassFailure(); @@ -1491,7 +1497,8 @@ class InferMeshAssignMeshForFuncLeavesPass return IsMeshBeforeOtherMesh(a, b); }); for (StringRef mesh_name : mesh_names) { - WrapOpWithFragment(op, mesh_name, rewriter); + WrapOpWithFragment(op, mesh_name, rewriter, + "assign_mesh_for_func_leaves"); if (isPure(op)) { // For pure ops, we only need to wrap it in a fragment once. But for // non-pure ops, we need to keep them associated with each src. @@ -2281,7 +2288,7 @@ void WrapBasedOnAssignUsers(Operation* op, RewriterBase& rewriter) { }); for (StringRef mesh_name : user_mesh_types_vec) { WrapOpWithFragment( - op, mesh_name, rewriter, + op, mesh_name, rewriter, "rewrite_using_analysis", /*should_replace_use=*/[&mesh_name](OpOperand& use) { if (auto assign_user = dyn_cast(use.getOwner())) { return assign_user.getType().getMeshName() == mesh_name; @@ -2313,8 +2320,8 @@ void AssignOpBasedOnConsumers(Operation* op, const int max_clones, // Non-tensor results (e.g., tokens) bypass AssignOp, so their users // will not be AssignOps. Skip them. if (!assign_op) { - SDY_CHECK(llvm::none_of(op->getResultTypes(), - llvm::IsaPred)); + SDY_CHECK( + llvm::none_of(op->getResultTypes(), llvm::IsaPred)); continue; } for (Operation* assign_user : assign_op->getUsers()) { diff --git a/shardy/dialect/mpmd/transforms/import/introduce_transfers.cc b/shardy/dialect/mpmd/transforms/import/introduce_transfers.cc index 8e0e9392c..b46ba505a 100644 --- a/shardy/dialect/mpmd/transforms/import/introduce_transfers.cc +++ b/shardy/dialect/mpmd/transforms/import/introduce_transfers.cc @@ -200,7 +200,8 @@ class PushAssignBackwardThroughAdd : public OpRewritePattern { if (failed(IsAddOfUnassigns(meshless_add, assign, rewriter))) { return failure(); } - WrapOpWithFragment(meshless_add, assign.getType().getMeshName(), rewriter); + WrapOpWithFragment(meshless_add, assign.getType().getMeshName(), rewriter, + "introduce_transfers"); return success(); } }; diff --git a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_assign_mesh_for_func_op_leaves.mlir b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_assign_mesh_for_func_op_leaves.mlir index 9a4bd8427..3750d0f86 100644 --- a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_assign_mesh_for_func_op_leaves.mlir +++ b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_assign_mesh_for_func_op_leaves.mlir @@ -205,7 +205,7 @@ func.func @op_with_no_result_without_src_set(%arg0: tensor<4x8xf32>) -> tensor<4 <"m1": <["x"=2]>>, <"m2": <["y"=2]>>, <"m3": <["z"=2]>> >} { // CHECK-NEXT: %[[ASSIGN:.*]] = mpmd.assign %arg0 -// CHECK-NEXT: mpmd.fragment (%[[ASSIGN]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[ASSIGN]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } @@ -221,7 +221,7 @@ func.func @op_with_no_result_with_src_set(%arg0: tensor<4x8xf32>) -> tensor<4x8x <"m1": <["x"=2]>>, <"m2": <["y"=2]>>, <"m3": <["z"=2]>> >} { // CHECK-NEXT: %[[ASSIGN:.*]] = mpmd.assign %arg0 -// CHECK-NEXT: mpmd.fragment (%[[ASSIGN]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[ASSIGN]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } @@ -237,12 +237,12 @@ func.func @op_with_no_result_with_multi_src_set(%arg0: tensor<4x8xf32>) -> tenso <"m1": <["x"=2]>>, <"m2": <["y"=2]>>, <"m3": <["z"=2]>> >} { // CHECK-NEXT: %[[ASSIGN_2:.*]] = mpmd.assign %arg0 -// CHECK-NEXT: mpmd.fragment (%[[ASSIGN_2]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[ASSIGN_2]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } // CHECK-NEXT: %[[ASSIGN_3:.*]] = mpmd.assign %arg0 -// CHECK-NEXT: mpmd.fragment (%[[ASSIGN_3]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[ASSIGN_3]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } diff --git a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_pipeline.mlir b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_pipeline.mlir index fa4a49453..86e284428 100644 --- a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_pipeline.mlir +++ b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_pipeline.mlir @@ -251,11 +251,11 @@ func.func @assign_of_non_scalar_const() <"m1": <["x"=2]>>, <"m2": <["y"=2]>> >} { -// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment () () { +// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment () {mpmd.inferred_by = ["rewrite_using_analysis"]} () { // CHECK-NEXT: %[[CONST_1:.*]] = stablehlo.constant dense<1> : tensor<5x5xui32> // CHECK-NEXT: mpmd.return %[[CONST_1]] // CHECK-NEXT: } -// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment () () { +// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment () {mpmd.inferred_by = ["rewrite_using_analysis"]} () { // CHECK-NEXT: %[[CONST_2:.*]] = stablehlo.constant dense<1> : tensor<5x5xui32> // CHECK-NEXT: mpmd.return %[[CONST_2]] // CHECK-NEXT: } @@ -408,11 +408,11 @@ func.func @op_with_no_results(%arg0: tensor<4x16xf32>) <"m1": <["x"=2]>>, <"m2": <["y"=2]>> >} { -// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%arg0) +// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%arg0) {mpmd.inferred_by = ["rewrite_using_analysis"]} // CHECK-NEXT: stablehlo.add %arg1, %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } -// CHECK: mpmd.fragment (%[[INFERRED_1]]) (%arg1 +// CHECK: mpmd.fragment (%[[INFERRED_1]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } @@ -432,19 +432,19 @@ func.func @op_with_no_results_multiple_meshes(%arg0: !mesh_2_tensor_4_16_f32) <"m2": <["y"=2]>> >} { // CHECK-NEXT: %[[TRANSFER:.*]] = mpmd.transfer %arg0 -// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%[[TRANSFER]]) +// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%[[TRANSFER]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} // CHECK-NEXT: stablehlo.add %arg1, %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } -// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment (%arg0) +// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment (%arg0) {mpmd.inferred_by = ["rewrite_using_analysis"]} // CHECK-NEXT: stablehlo.add %arg1, %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } -// CHECK-NEXT: mpmd.fragment (%[[INFERRED_1]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[INFERRED_1]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } -// CHECK-NEXT: mpmd.fragment (%[[INFERRED_2]]) (%arg1 +// CHECK-NEXT: mpmd.fragment (%[[INFERRED_2]]) {mpmd.inferred_by = ["assign_mesh_for_func_leaves"]} (%arg1 // CHECK-NEXT: sdy.sharding_group %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } @@ -587,7 +587,7 @@ func.func @multiple_meshes_complex(%arg0: tensor<4x8xf32>, <"m1": <["x"=2]>>, <"m2": <["y"=2]>> >} { -// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%arg2, %arg4) (%arg5: tensor<16x8xf32>, %arg6: tensor<16x8xf32>) { +// CHECK-NEXT: %[[INFERRED_1:.*]] = mpmd.fragment (%arg2, %arg4) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg5: tensor<16x8xf32>, %arg6: tensor<16x8xf32>) { // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg5, %arg6 : tensor<16x8xf32> // CHECK-NEXT: mpmd.return %[[ADD_1]] : tensor<16x8xf32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m2", tensor<16x8xf32>>, !mpmd.mesh_tensor<"m2", tensor<16x8xf32>>) -> !mpmd.mesh_tensor<"m2", tensor<16x8xf32>> @@ -601,7 +601,7 @@ func.func @multiple_meshes_complex(%arg0: tensor<4x8xf32>, // CHECK-NEXT: %[[DOT_2:.*]] = stablehlo.dot %arg5, %[[ADD_2]] : (tensor<4x16xf32>, tensor<16x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: mpmd.return %[[DOT_2]] : tensor<4x8xf32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m2", tensor<4x16xf32>>, !mpmd.mesh_tensor<"m2", tensor<16x8xf32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> -// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment (%[[FRAGMENT_1]], %arg3) (%arg5: tensor<4x16xf32>, %arg6: tensor<4x16xf32>) { +// CHECK-NEXT: %[[INFERRED_2:.*]] = mpmd.fragment (%[[FRAGMENT_1]], %arg3) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg5: tensor<4x16xf32>, %arg6: tensor<4x16xf32>) { // CHECK-NEXT: %[[ADD_3:.*]] = stablehlo.add %arg5, %arg6 : tensor<4x16xf32> // CHECK-NEXT: %[[ADD_4:.*]] = stablehlo.add %[[ADD_3]], %[[ADD_3]] : tensor<4x16xf32> // CHECK-NEXT: mpmd.return %[[ADD_4]] : tensor<4x16xf32> diff --git a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_rewrite_using_analysis.mlir b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_rewrite_using_analysis.mlir index a7707d2d0..1cb8c112b 100644 --- a/shardy/dialect/mpmd/transforms/import/test/infer_mesh_rewrite_using_analysis.mlir +++ b/shardy/dialect/mpmd/transforms/import/test/infer_mesh_rewrite_using_analysis.mlir @@ -14,7 +14,7 @@ func.func @simple_rewrite(%arg0: tensor<4x8xf32> {mpmd.src_set = #mpmd.meshes_with_origins<"m1">, mpmd.use_set = #mpmd.meshes_with_origins<"m1">}) -> !mesh_1_tensor_4_8_f32 attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // CHECK-NEXT: %[[ASSIGN:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG:.*]] = mpmd.fragment (%[[ASSIGN]]) (%arg1: tensor<4x8xf32>) { + // CHECK-NEXT: %[[ADD_FRAG:.*]] = mpmd.fragment (%[[ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1: tensor<4x8xf32>) { // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<1.000000e+00> // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %[[CONST]] // CHECK-NEXT: mpmd.return %[[ADD]] @@ -34,7 +34,7 @@ func.func @rewrite_with_duplication(%arg0: tensor<4x8xf32> {mpmd.src_set = #mpmd // CHECK-DAG: %[[ASSIGN_ARG0_1:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", - // CHECK-DAG: %[[ADD_FRAG_1:.*]] = mpmd.fragment (%[[ASSIGN_ARG0_1]]) (%arg1: tensor<4x8xf32>) { + // CHECK-DAG: %[[ADD_FRAG_1:.*]] = mpmd.fragment (%[[ASSIGN_ARG0_1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1: tensor<4x8xf32>) { // CHECK-NEXT: %[[CONST_1:.*]] = stablehlo.constant dense<1.000000e+00> // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg1, %[[CONST_1]] // CHECK-NEXT: mpmd.return %[[ADD_1]] @@ -42,7 +42,7 @@ func.func @rewrite_with_duplication(%arg0: tensor<4x8xf32> {mpmd.src_set = #mpmd // CHECK-DAG: %[[UNASSIGN_1:.*]] = mpmd.unassign %[[ADD_FRAG_1]] : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-DAG: %[[ASSIGN_ARG0_2:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", - // CHECK-DAG: %[[ADD_FRAG_2:.*]] = mpmd.fragment (%[[ASSIGN_ARG0_2]]) (%arg1: tensor<4x8xf32>) { + // CHECK-DAG: %[[ADD_FRAG_2:.*]] = mpmd.fragment (%[[ASSIGN_ARG0_2]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1: tensor<4x8xf32>) { // CHECK-NEXT: %[[CONST_2:.*]] = stablehlo.constant dense<1.000000e+00> // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %arg1, %[[CONST_2]] // CHECK-NEXT: mpmd.return %[[ADD_2]] @@ -66,14 +66,14 @@ func.func @call_op(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> (!mesh_1_t attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // CHECK-NEXT: %[[ARG0_ASSIGN:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } // CHECK-NEXT: %[[F0_UNASSIGN:.*]] = mpmd.unassign %1 : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG1_ASSIGN:.*]] = mpmd.assign %arg1 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } @@ -113,12 +113,12 @@ func.func private @call_op_f( // CHECK-DAG: %[[ASSIGN_ARG0:.*]] = mpmd.assign %[[UNASSIGN_ARG0]] : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> // CHECK-DAG: %[[ASSIGN_ARG1:.*]] = mpmd.assign %[[UNASSIGN_ARG1]] : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]]) (%arg2 + // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-DAG: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-DAG: mpmd.return %[[ADD]] // CHECK-DAG: } - // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_ARG1]]) (%arg2 + // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_ARG1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-DAG: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-DAG: mpmd.return %[[ADD]] // CHECK-DAG: } @@ -140,13 +140,13 @@ func.func @call_op_multiple_calls(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32> attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // CHECK-NEXT: %[[ARG0_ASSIGN:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } // CHECK-NEXT: %[[F0_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG0]] : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG1_ASSIGN:.*]] = mpmd.assign %arg1 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } @@ -192,11 +192,11 @@ func.func private @call_op_multiple_calls_f( ) attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { - // CHECK: %[[ADD_FRAG0:.*]] = mpmd.fragment ({{.*}}) (%arg2 + // CHECK: %[[ADD_FRAG0:.*]] = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } - // CHECK: %[[ADD_FRAG1:.*]] = mpmd.fragment ({{.*}}) (%arg2 + // CHECK: %[[ADD_FRAG1:.*]] = mpmd.fragment ({{.*}}) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } @@ -211,14 +211,14 @@ func.func @call_op_unused_output(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // CHECK-NEXT: %[[ARG0_ASSIGN:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } // CHECK-NEXT: %[[F0_UNASSIGN:.*]] = mpmd.unassign %1 : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG1_ASSIGN:.*]] = mpmd.assign %arg1 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) (%arg2 + // CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg2 // CHECK-NEXT: mpmd.return %[[ADD]] // CHECK-NEXT: } @@ -255,8 +255,8 @@ func.func private @call_op_unused_output_f( // CHECK-DAG: %[[ASSIGN_ARG0:.*]] = mpmd.assign %[[UNASSIGN_ARG0]] : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> // CHECK-DAG: %[[ASSIGN_ARG1:.*]] = mpmd.assign %[[UNASSIGN_ARG1]] : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]]) (%arg2 - // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_ARG1]]) (%arg2 + // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 + // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_ARG1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2 // CHECK-DAG: %[[F0_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG0]] : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-DAG: %[[F1_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG1]] : (!mpmd.mesh_tensor<"m2", tensor<4x8xf32>>) -> tensor<4x8xf32> @@ -388,17 +388,17 @@ func.func @call_op_with_multi_result_assignment(%arg0: tensor<4x8xf32>, %arg1: t { // CHECK-NEXT: %[[ARG0_ASSIGN_0:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> - // CHECK-NEXT: %[[MULT_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_0]]) + // CHECK-NEXT: %[[MULT_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_0]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} // CHECK-NEXT: stablehlo.multiply // CHECK: %[[MULT_FRAG0_UNASSIGN:.*]] = mpmd.unassign %[[MULT_FRAG0]] : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG0_ASSIGN_1:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-NEXT: %[[MULT_FRAG1:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_1]]) (%arg2: tensor<4x8xf32>) { + // CHECK-NEXT: %[[MULT_FRAG1:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2: tensor<4x8xf32>) { // CHECK-NEXT: stablehlo.multiply // CHECK: %[[MULT_FRAG1_UNASSIGN:.*]] = mpmd.unassign %[[MULT_FRAG1]] : (!mpmd.mesh_tensor<"m2", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG1_ASSIGN:.*]] = mpmd.assign %arg1 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> - // CHECK-NEXT: %[[ADD_FRAG:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) (%arg2: tensor<4x8xf32>) { + // CHECK-NEXT: %[[ADD_FRAG:.*]] = mpmd.fragment (%[[ARG1_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2: tensor<4x8xf32>) { // CHECK-NEXT: stablehlo.add // CHECK: %[[ADD_FRAG_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG]] : (!mpmd.mesh_tensor<"m2", tensor<4x8xf32>>) -> tensor<4x8xf32> @@ -472,14 +472,14 @@ func.func @call_op_noop_return(%arg0: tensor<4x8xf32>) -> (!mesh_1_tensor_4_8_f3 attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // CHECK-NEXT: %[[ARG0_ASSIGN_0:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> -// CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_0]]) (%arg1 +// CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_0]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1 // CHECK-NEXT: stablehlo.add %arg1, %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } // CHECK-NEXT: %[[F0_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG0]] : (!mpmd.mesh_tensor<"m1", tensor<4x8xf32>>) -> tensor<4x8xf32> // CHECK-NEXT: %[[ARG0_ASSIGN_1:.*]] = mpmd.assign %arg0 : (tensor<4x8xf32>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xf32>> -// CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_1]]) (%arg1 +// CHECK-NEXT: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ARG0_ASSIGN_1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1 // CHECK-NEXT: stablehlo.add %arg1, %arg1 // CHECK-NEXT: mpmd.return // CHECK-NEXT: } @@ -660,7 +660,7 @@ func.func @single_fragment_consumer_but_used_by_return(%arg0: tensor<4x8xf32>) - attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["y"=2]>>>} { // The constant cannot be inlined into the consumer fragment because it is // used by the return statement. - // CHECK-NEXT: mpmd.fragment () () + // CHECK-NEXT: mpmd.fragment () {mpmd.inferred_by = ["rewrite_using_analysis"]} () // CHECK-NEXT: constant // CHECK: mpmd.fragment %0 = stablehlo.constant dense<1.0> : tensor<4x8xf32> @@ -717,7 +717,7 @@ func.func @op_operands_used_by_consumer(%arg0: !mesh_1_tensor_4_8_f32) -> !mesh_ func.func @fori_loop(%arg0: tensor {mpmd.src_set = #mpmd.meshes_with_origins<"m1">, mpmd.use_set = #mpmd.meshes_with_origins<"m1">}) -> (tensor, tensor) attributes {topology = #mpmd.topology<<"m1" : <["x"=2]>>, <"m2" : <["y"=2]>>>} { // CHECK-NEXT: %[[ARG0_ASSIGN:.*]] = mpmd.assign %arg0 : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> - // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) (%arg1: tensor) { + // CHECK-NEXT: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ARG0_ASSIGN]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg1: tensor) { // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 : tensor // CHECK-NEXT: mpmd.return %[[ADD]] : tensor // CHECK-NEXT: } @@ -735,7 +735,7 @@ attributes {topology = #mpmd.topology<<"m1" : <["x"=2]>>, <"m2" : <["y"=2]>>>} { // CHECK-DAG: %[[ASSIGN_INDEX:.*]] = mpmd.assign %[[UNASSIGN_INDEX]] : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> // CHECK-DAG: %[[ASSIGN_ARG1:.*]] = mpmd.assign %[[UNASSIGN_ARG1]] : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> - // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_INDEX]], %[[ASSIGN_ARG1]]) (%arg3: tensor, %arg4: tensor) { + // CHECK-DAG: %[[ADD_FRAG1:.*]] = mpmd.fragment (%[[ASSIGN_INDEX]], %[[ASSIGN_ARG1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg3: tensor, %arg4: tensor) { // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<1> : tensor // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg4, %[[CONST]] : tensor // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[ADD0]], %arg3 : tensor @@ -744,7 +744,7 @@ attributes {topology = #mpmd.topology<<"m1" : <["x"=2]>>, <"m2" : <["y"=2]>>>} { // CHECK-DAG: %[[F0_UNASSIGN:.*]] = mpmd.unassign %[[ADD_FRAG1]] : (!mpmd.mesh_tensor<"m1", tensor>) -> tensor // CHECK-DAG: %[[ASSIGN_ARG2:.*]] = mpmd.assign %[[UNASSIGN_ARG2]] : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> - // CHECK-DAG: %[[ADD_FRAG2:.*]] = mpmd.fragment (%[[ASSIGN_ARG2]]) (%arg3: tensor) { + // CHECK-DAG: %[[ADD_FRAG2:.*]] = mpmd.fragment (%[[ASSIGN_ARG2]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg3: tensor) { // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<1> : tensor // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg3, %[[CONST]] : tensor // CHECK-NEXT: mpmd.return %[[ADD0]] : tensor @@ -800,7 +800,7 @@ attributes {topology = #mpmd.topology<<"m1" : <["x"=1]>>>} { // CHECK-DAG: %[[UNASSIGN_ARG1:.*]] = mpmd.unassign %arg1 : (!mpmd.mesh_tensor<"m1", tensor>) -> tensor // CHECK-DAG: %[[ASSIGN_ARG0:.*]] = mpmd.assign %[[UNASSIGN_ARG0]] : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> // CHECK-DAG: %[[ASSIGN_ARG1:.*]] = mpmd.assign %[[UNASSIGN_ARG1]] : (tensor) -> !mpmd.mesh_tensor<"m1", tensor> - // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]], %[[ASSIGN_ARG1]]) (%arg2: tensor, %arg3: tensor) { + // CHECK-DAG: %[[ADD_FRAG0:.*]] = mpmd.fragment (%[[ASSIGN_ARG0]], %[[ASSIGN_ARG1]]) {mpmd.inferred_by = ["rewrite_using_analysis"]} (%arg2: tensor, %arg3: tensor) { // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %arg2, %arg3 : tensor // CHECK-NEXT: mpmd.return %[[ADD0]] : tensor // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m1", tensor>, !mpmd.mesh_tensor<"m1", tensor>) -> !mpmd.mesh_tensor<"m1", tensor> diff --git a/shardy/dialect/mpmd/transforms/sharding_propagation/extract_reshards_from_inter_mesh_transfers.cc b/shardy/dialect/mpmd/transforms/sharding_propagation/extract_reshards_from_inter_mesh_transfers.cc index cbd9fcb55..6aaabfcbb 100644 --- a/shardy/dialect/mpmd/transforms/sharding_propagation/extract_reshards_from_inter_mesh_transfers.cc +++ b/shardy/dialect/mpmd/transforms/sharding_propagation/extract_reshards_from_inter_mesh_transfers.cc @@ -152,9 +152,9 @@ void HandleTransfer(TransferOp transfer, RewriterBase& rewriter, // %t = transfer (%r) : () -> OpOperand& operand = transfer->getOpOperand(0); Value value = operand.get(); - MeshTensorType new_operand_type = MeshTensorType::get( - rewriter.getContext(), src_mesh_type.getMeshName(), - dst_mesh_type.getRankedTensorType()); + MeshTensorType new_operand_type = + MeshTensorType::get(rewriter.getContext(), src_mesh_type.getMeshName(), + dst_mesh_type.getRankedTensorType()); if (isa(value) || isa(value.getDefiningOp())) { // We do not want to update the type of the block argument, not to // interfere with the function signature (and its shardings). @@ -162,6 +162,9 @@ void HandleTransfer(TransferOp transfer, RewriterBase& rewriter, FragmentOp reshard = FragmentOp::createMeshFragmentWithGlobalBody( value.getLoc(), /*user_origin=*/{}, src_mesh_type.getMeshName(), value, new_operand_type, rewriter, reshard_body); + reshard->setAttr( + kInferredByAttr, + rewriter.getArrayAttr({rewriter.getStringAttr("extract_reshards")})); reshard.setUserSpecifiedResultSharding(0, dst_sharding_or_null); operand.set(reshard.getResult(0)); return; @@ -177,6 +180,9 @@ void HandleTransfer(TransferOp transfer, RewriterBase& rewriter, FragmentOp reshard = FragmentOp::createMeshFragmentWithGlobalBody( value.getLoc(), /*user_origin=*/{}, new_operand_type.getMeshName(), value, value.getType(), rewriter, reshard_body); + reshard->setAttr( + kInferredByAttr, + rewriter.getArrayAttr({rewriter.getStringAttr("extract_reshards")})); reshard.setUserSpecifiedResultSharding(0, src_sharding_or_null); rewriter.replaceUsesWithIf( value, reshard.getResult(0), [transfer](OpOperand& use) { @@ -231,6 +237,9 @@ void HandleTransfer(TransferOp transfer, RewriterBase& rewriter, MeshTensorType::get(rewriter.getContext(), dst_mesh_type.getMeshName(), dst_mesh_type.getRankedTensorType()), rewriter, reshard_body); + reshard->setAttr( + kInferredByAttr, + rewriter.getArrayAttr({rewriter.getStringAttr("extract_reshards")})); reshard.setUserSpecifiedResultSharding(0, dst_sharding_or_null); rewriter.replaceUsesWithIf( new_transfer.getResult(), reshard.getResult(0), [](OpOperand& use) { diff --git a/shardy/dialect/mpmd/transforms/sharding_propagation/test/extract_reshard_from_inter_mesh_transfer.mlir b/shardy/dialect/mpmd/transforms/sharding_propagation/test/extract_reshard_from_inter_mesh_transfer.mlir index 298498429..efbe0e2a5 100644 --- a/shardy/dialect/mpmd/transforms/sharding_propagation/test/extract_reshard_from_inter_mesh_transfer.mlir +++ b/shardy/dialect/mpmd/transforms/sharding_propagation/test/extract_reshard_from_inter_mesh_transfer.mlir @@ -12,7 +12,7 @@ func.func @reshard_on_consumer_when_same_local_type_size(%arg0: !mpmd.mesh_tenso attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>} { // CHECK-NEXT: %[[TRANSFER_RESULT:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {?}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> - // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) (%arg1: tensor<4x8xui32>) { + // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<4x8xui32>) { // CHECK-NEXT: mpmd.return %arg1 : tensor<4x8xui32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m2", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> %0 = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"x"}]>]>} %arg0: (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> @@ -32,7 +32,7 @@ func.func @reshard_on_producer_when_local_type_smaller_on_producer(%arg0: !mpmd. -> (!mpmd.mesh_tensor<"m2", tensor<4x8xui32>> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {?}]>}, !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {?}]>}) attributes {"topology"=#mpmd.topology<<"m1": <["x"=2, "y"=4]>>, <"m2": <["x"=2, "y"=4]>>>} { - // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%arg0) (%arg1: tensor<4x8xui32>) { + // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%arg0) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<4x8xui32>) { // CHECK-NEXT: mpmd.return %arg1 : tensor<4x8xui32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> // CHECK-NEXT: %[[TRANSFER_RESULT:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {?}]>]>} %0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> @@ -65,7 +65,7 @@ func.func @intra_mesh_transfer_introduces_reshard(%arg0: !mpmd.mesh_tensor<"m1", attributes {"topology"=#mpmd.topology<<"m1": <["x"=2]>>>} { // CHECK-NEXT: %[[TRANSFER_RESULT:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {?}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> - // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) (%arg1: tensor<4x8xui32>) { + // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<4x8xui32>) { // CHECK-NEXT: mpmd.return %arg1 : tensor<4x8xui32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> %0 = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"x"}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> @@ -80,7 +80,7 @@ func.func @same_mesh_different_memory_kind(%arg0: !mpmd.mesh_tensor<"m1", tensor // CHECK-NEXT: %[[TRANSFER_RESULT:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {?}]>]>} %arg0 : // CHECK-SAME: (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>, memory_kind="pinned_host">) -> // CHECK-SAME: !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> - // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) (%arg1: tensor<4x8xui32>) + // CHECK-NEXT: %[[RESHARD_RESULT:.*]] = mpmd.fragment]> (%[[TRANSFER_RESULT]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<4x8xui32>) %0 = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"x"}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>, memory_kind="pinned_host">) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> func.return %0 : !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> } @@ -124,7 +124,7 @@ func.func @reshard_on_consumer_fragments_and_with_new_fragment(%arg0: !mpmd.mesh // Create a new reshard fragment as the transfer is used by another transfer // and by the return op. - // CHECK-NEXT: %[[RESHARD:.*]] = mpmd.fragment]> (%[[TRANSFER]]) + // CHECK-NEXT: %[[RESHARD:.*]] = mpmd.fragment]> (%[[TRANSFER]]) {mpmd.inferred_by = ["extract_reshards"]} // CHECK-NEXT: mpmd.return // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m2", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> %t = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"x"}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> @@ -194,7 +194,7 @@ func.func @create_fragment_when_producer_is_transfer(%arg0: !mpmd.mesh_tensor<"m attributes {"topology"=#mpmd.topology<<"m1": <["x"=2, "y"=4]>>, <"m2": <["x"=2, "y"=4]>>>} { // CHECK-NEXT: %[[TRANSFER:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {?}]>]>} %arg0 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> - // CHECK-NEXT: %[[RESHARD:.*]] = mpmd.fragment]> (%[[TRANSFER]]) (%arg1: tensor<4x8xui32>) { + // CHECK-NEXT: %[[RESHARD:.*]] = mpmd.fragment]> (%[[TRANSFER]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<4x8xui32>) { // CHECK-NEXT: mpmd.return %arg1 : tensor<4x8xui32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m1", tensor<4x8xui32>> // CHECK-NEXT: %[[TRANSFER_RESULT:.*]] = mpmd.transfer {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {?}]>]>} %1 : (!mpmd.mesh_tensor<"m1", tensor<4x8xui32>>) -> !mpmd.mesh_tensor<"m2", tensor<4x8xui32>> diff --git a/shardy/dialect/mpmd/transforms/sharding_propagation/test/sharding_propagation_pipeline.mlir b/shardy/dialect/mpmd/transforms/sharding_propagation/test/sharding_propagation_pipeline.mlir index 4d2b53fcf..203d99f07 100644 --- a/shardy/dialect/mpmd/transforms/sharding_propagation/test/sharding_propagation_pipeline.mlir +++ b/shardy/dialect/mpmd/transforms/sharding_propagation/test/sharding_propagation_pipeline.mlir @@ -335,7 +335,7 @@ func.func @introduce_reshard_for_transfer_operand_and_result_with_different_shar } : (!mesh_1_tensor_8_2_f32, !mesh_1_tensor_8_2_f32) -> !mesh_1_tensor_8_2_f32 // CHECK: %[[FRAG:.*]] = mpmd.fragment (%arg0, %arg1) (%arg2: tensor<8x2xf32>, %arg3: tensor<8x2xf32>) { // CHECK: %[[TRANSFER:.*]] = mpmd.transfer %0 : (!mpmd.mesh_tensor<"mesh1", tensor<8x2xf32>, sharding=<@mesh, [{"devices"}, {}]>>) -> !mpmd.mesh_tensor<"mesh2", tensor<8x2xf32>, sharding=<@mesh, [{"devices"}, {}]>> - // CHECK: %[[RESHARD:.*]] = mpmd.fragment (%[[TRANSFER]]) (%arg2: tensor<8x2xf32>) { + // CHECK: %[[RESHARD:.*]] = mpmd.fragment (%[[TRANSFER]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg2: tensor<8x2xf32>) { // CHECK: mpmd.return %arg2 : tensor<8x2xf32> // CHECK: } : (!mpmd.mesh_tensor<"mesh2", tensor<8x2xf32>, sharding=<@mesh, [{"devices"}, {}]>>) -> !mpmd.mesh_tensor<"mesh2", tensor<8x2xf32>, sharding=<@mesh, [{}, {}]>> // CHECK: return %[[FRAG]], %[[RESHARD]] : !mpmd.mesh_tensor<"mesh1", tensor<8x2xf32>, sharding=<@mesh, [{"devices"}, {}]>>, !mpmd.mesh_tensor<"mesh2", tensor<8x2xf32>, sharding=<@mesh, [{}, {}]>> @@ -528,7 +528,7 @@ sdy.mesh @mesh = <["x"=8]> func.func @introduce_reshard_for_arg( %arg0: !mesh_1_tensor_16_32_f32 {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> (!mesh_2_tensor_16_32_f32 {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) attributes {topology=#homogenous_topology} { - // CHECK: %[[RESHARD:.*]] = mpmd.fragment (%arg0) (%arg1: tensor<16x32xf32>) { + // CHECK: %[[RESHARD:.*]] = mpmd.fragment (%arg0) {mpmd.inferred_by = ["extract_reshards"]} (%arg1: tensor<16x32xf32>) { // CHECK-NEXT: mpmd.return %arg1 : tensor<16x32xf32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m1", tensor<16x32xf32>, sharding=<@mesh, [{}, {}]>>) -> // CHECK-SAME: !mpmd.mesh_tensor<"m1", tensor<16x32xf32>, sharding=<@mesh, [{"x"}, {}]>> @@ -717,7 +717,7 @@ func.func @return_value_used_in_another_fragment(%arg0: !mesh_1_tensor {sdy.shar // CHECK-SAME: (!mpmd.mesh_tensor<"m1", tensor<4xf32>, sharding=<@mesh, [{"y"}]>>) // CHECK-SAME: -> !mpmd.mesh_tensor<"m2", tensor<4xf32>, sharding=<@mesh, [{"y"}]>> %0 = mpmd.transfer %arg0 : (!mesh_1_tensor) -> !mesh_2_tensor // %arg0 is sharded by y axis - // CHECK: mpmd.fragment (%[[TRANSFER_RESULT]]) (%arg2: tensor<4xf32>) { + // CHECK: mpmd.fragment (%[[TRANSFER_RESULT]]) {mpmd.inferred_by = ["extract_reshards"]} (%arg2: tensor<4xf32>) { // CHECK-NEXT: mpmd.return %arg2 : tensor<4xf32> // CHECK-NEXT: } : (!mpmd.mesh_tensor<"m2", tensor<4xf32>, sharding=<@mesh, [{"y"}]>>) // CHECK-SAME: -> !mpmd.mesh_tensor<"m2", tensor<4xf32>, sharding=<@mesh, [{"x"}]>>