Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion shardy/dialect/mpmd/ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ SmallVector<MpmdDataflowEdge> GetMpmdDataflowEdges(FuncOp func_op) {

FragmentOp WrapOpWithFragment(
Operation* op, StringRef mesh_name, RewriterBase& rewriter,
std::function<bool(OpOperand&)> should_replace_use) {
StringRef inferred_by, std::function<bool(OpOperand&)> should_replace_use) {
// We set the insertion point right before `op` so assigns of operands will be
// in the right place regardless of previous insertion point.
rewriter.setInsertionPoint(op);
Expand Down Expand Up @@ -437,6 +437,10 @@ FragmentOp WrapOpWithFragment(
return block_builder.clone(*op, mapping)->getResults();
});

fragment_op->setAttr(
kInferredByAttr,
rewriter.getArrayAttr({rewriter.getStringAttr(inferred_by)}));

// Unassign all fragment results and replace all uses of `op` with the
// corresponding unassign op for which `should_replace_use` returns true.
for (auto [original_result, fragment_result] :
Expand Down
16 changes: 9 additions & 7 deletions shardy/dialect/mpmd/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ limitations under the License.

namespace mlir::mpmd {

// Globsl sdy mesh name.
constexpr StringRef kGlobalMeshName = "mesh";

// Global sdy mesh name.
inline constexpr StringRef kGlobalMeshName = "mesh";
// The function attribute that holds the SPMD mesh.
constexpr StringRef kMeshShapeAttr = "mesh_shape";
inline constexpr StringRef kMeshShapeAttr = "mesh_shape";
// The function attribute that holds the MPMD topology.
constexpr StringRef kTopologyAttr = "topology";

inline constexpr StringRef kTopologyAttr = "topology";

// The suffix of the mesh name for a CPU mesh.
// LINT.IfChange
constexpr StringRef kCpuMeshSuffix = "/cpu";
inline constexpr StringRef kCpuMeshSuffix = "/cpu";
// LINT.ThenChange(
// https://github.com/openxla/shardy/blob/main/shardy/integrations/python/jax/mpmd/types.py
// )
Expand All @@ -76,6 +74,9 @@ inline constexpr StringRef kRematAttributeName = "remat";

inline constexpr StringRef kJaxResultInfoAttr = "jax.result_info";

// The attribute that holds the list of pass names that inferred a fragment.
inline constexpr StringRef kInferredByAttr = "mpmd.inferred_by";

template <typename... Args>
std::string StrCat(Args&&... args) {
std::string result;
Expand Down Expand Up @@ -261,6 +262,7 @@ SmallVector<MpmdDataflowEdge> GetMpmdDataflowEdges(func::FuncOp func_op);
// `should_replace_use` returns true.
FragmentOp WrapOpWithFragment(
Operation* op, StringRef mesh_name, RewriterBase& rewriter,
StringRef inferred_by,
std::function<bool(OpOperand&)> should_replace_use = [](OpOperand&) {
return true;
});
Expand Down
1 change: 1 addition & 0 deletions shardy/dialect/mpmd/transforms/common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ cc_library(
"split_bwd_fragments.cc",
"uniquify_function_inputs_outputs.cc",
"unroll_for_loops.cc",
"wrap_block_arg_returns.cc",
],
hdrs = [
"merge_fragments.h",
Expand Down
32 changes: 32 additions & 0 deletions shardy/dialect/mpmd/transforms/common/merge_fragments.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <vector>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -182,6 +183,31 @@ std::optional<int> MergeCallCounters(FragmentOp producer_op,
return std::nullopt;
}

// Merges the `kInferredByAttr` array attributes of `producer_op` and
// `consumer_op`.
//
// Returns std::nullopt when neither fragment carries the attribute. Otherwise
// returns an array holding the producer's entries followed by the consumer's
// entries, with duplicates dropped and first-occurrence order preserved.
std::optional<ArrayAttr> MergeInferredByAttributes(FragmentOp producer_op,
                                                   FragmentOp consumer_op) {
  // Each lookup yields a null ArrayAttr when the attribute is absent.
  ArrayAttr producer_inferred_by =
      producer_op->getAttrOfType<ArrayAttr>(kInferredByAttr);
  ArrayAttr consumer_inferred_by =
      consumer_op->getAttrOfType<ArrayAttr>(kInferredByAttr);

  if (!producer_inferred_by && !consumer_inferred_by) {
    return std::nullopt;
  }

  // A SetVector de-duplicates while keeping insertion order, so the producer's
  // entries always come first in the merged list.
  llvm::SetVector<Attribute> combined_inferred_by;
  for (ArrayAttr inferred_by : {producer_inferred_by, consumer_inferred_by}) {
    if (inferred_by) {
      combined_inferred_by.insert(inferred_by.begin(), inferred_by.end());
    }
  }

  // Attributes are uniqued in the MLIRContext, so the array can be built
  // directly from the context without constructing a rewriter.
  return ArrayAttr::get(producer_op.getContext(),
                        combined_inferred_by.getArrayRef());
}

// Returns a list of attributes that must be preserved in the merged fragment.
// Note: origins are preserved by default and require no extra work.
SmallVector<std::pair<StringRef, Attribute>> MergedAttributes(
Expand All @@ -194,6 +220,12 @@ SmallVector<std::pair<StringRef, Attribute>> MergedAttributes(
attributes.emplace_back(kCallCounterAttrName,
rewriter.getUI32IntegerAttr(*merged_call_count));
}

if (std::optional<ArrayAttr> merged_inferred_by =
MergeInferredByAttributes(producer_op, consumer_op)) {
attributes.emplace_back(kInferredByAttr, *merged_inferred_by);
}

return attributes;
}

Expand Down
16 changes: 16 additions & 0 deletions shardy/dialect/mpmd/transforms/common/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ def RemoveTransferCyclesPass :
}];
}

// Declares the `mpmd-wrap-block-arg-returns` pass, built on
// `DistributedFunctionPass`. The pass depends on the MPMD dialect, presumably
// because it creates `mpmd.fragment` ops that wrap function block arguments
// which are returned directly.
def WrapBlockArgReturnsPass :
PassBase<"mpmd-wrap-block-arg-returns", "DistributedFunctionPass"> {
let summary = "Wraps block arguments that are directly returned by the "
"function in identity fragments.";
let description = [{
If a function directly returns a block argument, this pass creates an
identity fragment for that block argument, ensuring that sharding
constraints can be applied to an op result rather than a block argument.

This guarantees that values are passed by value to the function, not by
reference.
}];

let dependentDialects = ["mlir::mpmd::MpmdDialect"];
}

// TODO(jupvfranco): we should create these copies using ifrt_ir reshard. We are
// not ready for that yet as we need better support for donation in the
// presence of reshards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func.func @single_mesh_one_return_operand(%arg0: !mesh_1_tensor) -> (!mesh_1_ten
} {
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) (%arg1: tensor<4xf32>) {
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
Expand All @@ -49,7 +49,7 @@ func.func @needs_fragment_for_m1_with_many_values(%arg0: !mesh_1_tensor, %arg1:
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
// CHECK: %[[F3:.*]] = mpmd.fragment<mesh="m1", origin=["f3"]>
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
// CHECK: %[[UF:.*]]:5 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]], %[[F3]]) {mpmd.inferred_by = ["uniquify"]} (%[[A1:.*]]: tensor<4xf32>, %[[A2:.*]]: tensor<4xf32>)
// CHECK-NEXT: mpmd.return %[[A1]], %[[A1]], %[[A2]], %[[A2]], %[[A2]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F2]], %[[UF]]#0, %[[UF]]#2, %[[UF]]#1, %[[UF]]#3, %[[UF]]#4
Expand All @@ -70,8 +70,8 @@ func.func @needs_fragment_for_m1_and_m2(%arg0: !mesh_1_tensor, %arg1: !mesh_2_te
) -> (!mesh_1_tensor, !mesh_2_tensor, !mesh_2_tensor, !mesh_1_tensor, !mesh_1_tensor, !mesh_1_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
} {
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]>
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]>
// CHECK: %[[UF1:.*]]:4 = mpmd.fragment<mesh="m1", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
// CHECK: %[[UF2:.*]]:2 = mpmd.fragment<mesh="m2", origin=[]> ({{.*}}) {mpmd.inferred_by = ["uniquify"]}
// CHECK: return %[[UF1]]#0, %[[UF2]]#0, %[[UF2]]#1, %[[UF1]]#2, %[[UF1]]#1, %[[UF1]]#3
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
mpmd.return %arg2 : tensor<4xf32>
Expand All @@ -97,7 +97,7 @@ func.func @single_mesh_one_return_operand_with_global_view(%arg0: !dist_mesh_ten
} {
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m1", origin=["f2"]>
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) (%arg1: tensor<4xf32>) {
// CHECK: %[[UF:.*]]:2 = mpmd.fragment<mesh="m1", origin=[]> (%[[F1]]) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xf32>) {
// CHECK: mpmd.return %arg1, %arg1 : tensor<4xf32>, tensor<4xf32>
// CHECK: %[[F2]], %[[UF]]#0, %[[UF]]#1
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg1: tensor<4xf32>) {
Expand All @@ -122,7 +122,7 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m", origin=["f"]> (%arg0) (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1
// CHECK-NEXT: }
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
// CHECK-NEXT: %[[F2:.*]]:2 = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["uniquify"]} (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1, %arg1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F2]]#0, %[[F1]], %[[F2]]#1
Expand All @@ -140,9 +140,8 @@ func.func @f(%arg0: !mesh_tensor) -> (!mesh_tensor, !mesh_tensor, !mesh_tensor)
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
{
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F]]
// Block-arg wrapping is handled by a separate pass. The uniquify pass is a
// no-op here since %arg0 appears only once.
// CHECK-NEXT: return %arg0
func.return %arg0 : !mesh_tensor
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: mpmd_opt %s -mpmd-wrap-block-arg-returns -split-input-file 2>&1 | FileCheck %s

!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>>

// CHECK-LABEL: func @no_work_needed
func.func @no_work_needed(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_2_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
} {
// Every returned value is already the result of a fragment, so the pass is
// expected to leave the two fragments and the return operands as they are.
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[F2:.*]] = mpmd.fragment<mesh="m2", origin=["f2"]>
// CHECK: return %[[F1]], %[[F2]]
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
} : (!mesh_1_tensor) -> !mesh_1_tensor
%1 = mpmd.fragment<mesh="m2", origin=["f2"]> (%arg1) (%arg2: tensor<4xf32>) {
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
} : (!mesh_2_tensor) -> !mesh_2_tensor
return %0, %1 : !mesh_1_tensor, !mesh_2_tensor
}

// -----

!mesh_tensor = !mpmd.mesh_tensor<"m", tensor<4xui32>, sharding=<@mesh, [{"x"}]>>

module {

// CHECK-LABEL: func @identity_function
func.func @identity_function(%arg0: !mesh_tensor) -> !mesh_tensor
attributes {"topology"=#mpmd.topology<<"m": <["x"=2]>>>}
{
// The function returns its block argument directly. The pass is expected to
// insert an identity fragment on mesh "m", tagged with
// mpmd.inferred_by = ["wrap_block_arg_returns"], and return its result.
// CHECK-NEXT: %[[F:.*]] = mpmd.fragment<mesh="m", origin=[]> (%arg0) {mpmd.inferred_by = ["wrap_block_arg_returns"]} (%arg1: tensor<4xui32>) {
// CHECK-NEXT: return %arg1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[F]]
func.return %arg0 : !mesh_tensor
}

}

// -----

!mesh_1_tensor = !mpmd.mesh_tensor<"m1", tensor<4xf32>>
!mesh_2_tensor = !mpmd.mesh_tensor<"m2", tensor<4xf32>>

// CHECK-LABEL: func @mixed_returns
func.func @mixed_returns(%arg0: !mesh_1_tensor, %arg1: !mesh_2_tensor) -> (!mesh_1_tensor, !mesh_1_tensor, !mesh_2_tensor) attributes {
"topology"=#mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
} {
// Mixes directly returned block arguments (%arg0 on m1, %arg1 on m2) with a
// fragment result (%0). Each block argument is expected to be wrapped in a
// fragment on its own mesh, while the existing fragment result is returned
// unchanged and the return operand order is preserved.
// CHECK-NEXT: %[[F1:.*]] = mpmd.fragment<mesh="m1", origin=["f1"]>
// CHECK: %[[WRAP1:.*]] = mpmd.fragment<mesh="m1", origin=[]> (%arg0) {mpmd.inferred_by = ["wrap_block_arg_returns"]}
// CHECK: %[[WRAP2:.*]] = mpmd.fragment<mesh="m2", origin=[]> (%arg1) {mpmd.inferred_by = ["wrap_block_arg_returns"]}
// CHECK: return %[[WRAP1]], %[[F1]], %[[WRAP2]]
%0 = mpmd.fragment<mesh="m1", origin=["f1"]> (%arg0) (%arg2: tensor<4xf32>) {
%1 = stablehlo.add %arg2, %arg2 : tensor<4xf32>
mpmd.return %1 : tensor<4xf32>
} : (!mesh_1_tensor) -> !mesh_1_tensor
return %arg0, %0, %arg1 : !mesh_1_tensor, !mesh_1_tensor, !mesh_2_tensor
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
Expand All @@ -30,7 +31,6 @@ limitations under the License.
#include "shardy/dialect/mpmd/ir/utils.h"
#include "shardy/dialect/mpmd/transforms/common/passes.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/ir/dialect.h"
#include "mlir/IR/MLIRContext.h"

namespace mlir::mpmd {

Expand All @@ -44,16 +44,6 @@ using ValueToReturnIndices = llvm::MapVector<Value, SmallVector<int64_t>>;
void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
ValueToReturnIndices& value_to_return_indices,
OpBuilder& builder) {
// We remove any entries that require no work, in order to avoid too many
// checks.
value_to_return_indices.remove_if([](const auto& it) {
if (it.second.size() == 1) {
Value v = it.first;
return !isa<BlockArgument>(v);
}
return it.second.empty();
});

builder.setInsertionPoint(return_op);
SmallVector<Value> fragment_operands;
fragment_operands.reserve(value_to_return_indices.size());
Expand All @@ -65,15 +55,14 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
cast<MeshTensorType>(value.getType()));
}

if (fragment_operands.empty()) {
return;
}

auto loc = return_op->getLoc();
auto fragment_op = FragmentOp::create(
builder, loc, fragment_return_types, fragment_operands,
/*user_origin=*/ArrayAttr::get(builder.getContext(), {}),
/*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr());
fragment_op->setAttr(
kInferredByAttr,
builder.getArrayAttr({builder.getStringAttr("uniquify")}));
Block& fragment_block = fragment_op.getRegion().emplaceBlock();

SmallVector<Value> returned_values;
Expand All @@ -88,8 +77,7 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op,
returned_values.insert(
returned_values.end(), return_indices.size(),
fragment_block.addArgument(
GetGlobalTensorTypeFromMeshType(value, mesh_attr),
value.getLoc()));
GetGlobalTensorTypeFromMeshType(value, mesh_attr), value.getLoc()));

for (int64_t index : return_indices) {
return_op->setOperand(index,
Expand Down Expand Up @@ -150,7 +138,7 @@ class UniquifyFunctionInputOutputsPass
using UniquifyFunctionInputsOutputsPassBase::
UniquifyFunctionInputsOutputsPassBase;

private:
protected:
void runOnFunc(func::FuncOp func_op) override {
if (!IsMpmdFunction(func_op)) {
// This is not the main function. Do nothing.
Expand Down Expand Up @@ -179,6 +167,9 @@ class UniquifyFunctionInputOutputsPass
OpBuilder builder(&getContext());
for (auto& [mesh_name, value_to_return_indices] :
value_to_return_indices_per_mesh) {
// Only keep entries that are returned more than once (need uniquifying).
value_to_return_indices.remove_if(
[](const auto& it) { return it.second.size() <= 1; });
CreateReturnFragmentForMesh(mesh_name, return_op, value_to_return_indices,
builder);
}
Expand Down
Loading
Loading