Skip to content

Commit be9aa17

Browse files
akuegel authored and tensorflower-gardener committed
Use GetInPlaceInputOutputPairs from AliasInfo instead of HloDataflowAnalysis.
PiperOrigin-RevId: 837557364
1 parent ecaa6e1 commit be9aa17

10 files changed

Lines changed: 77 additions & 56 deletions

File tree

third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3079,7 +3079,10 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
30793079
reverse, {}, call, {}, &alias_info_));
30803080
}
30813081

3082-
using GetInPlaceInputOutputPairsTest = HloHardwareIndependentTestBase;
3082+
class GetInPlaceInputOutputPairsTest : public HloHardwareIndependentTestBase {
3083+
protected:
3084+
AliasInfo alias_info_;
3085+
};
30833086

30843087
TEST_F(GetInPlaceInputOutputPairsTest, DUS) {
30853088
const char* kModule = R"(
@@ -3095,7 +3098,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUS) {
30953098
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
30963099
HloInstruction* dus = module->entry_computation()->root_instruction();
30973100

3098-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(dus);
3101+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(dus);
30993102
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31003103
expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
31013104
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3122,7 +3125,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSFusion) {
31223125
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
31233126
HloInstruction* fusion = module->entry_computation()->root_instruction();
31243127

3125-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3128+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
31263129
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31273130
expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
31283131
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3150,7 +3153,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSFusionWithOutputOperandAliasing) {
31503153
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
31513154
HloInstruction* fusion = module->entry_computation()->root_instruction();
31523155

3153-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3156+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
31543157
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31553158
expected_pairs.push_back({HloOperandIndex{0, {}}, {1}}); // discovered
31563159
expected_pairs.push_back({HloOperandIndex{1, {}}, {0}}); // annotated
@@ -3176,7 +3179,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusion) {
31763179
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
31773180
HloInstruction* fusion = module->entry_computation()->root_instruction();
31783181

3179-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3182+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
31803183
EXPECT_THAT(in_place_pairs, IsEmpty());
31813184
}
31823185

@@ -3198,7 +3201,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusionWithOutputOperandAliasing) {
31983201
)";
31993202
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
32003203
HloInstruction* fusion = module->entry_computation()->root_instruction();
3201-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3204+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
32023205

32033206
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32043207
expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
@@ -3233,7 +3236,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedDUSFusion) {
32333236
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
32343237
HloInstruction* fusion = module->entry_computation()->root_instruction();
32353238

3236-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3239+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
32373240
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32383241
expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
32393242
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3277,12 +3280,12 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedMultiOutputDUSFusion) {
32773280
HloInstruction* inner_fusion = FindInstruction(module.get(), "inner_fusion");
32783281

32793282
auto inner_in_place_pairs =
3280-
HloDataflowAnalysis::GetInPlaceInputOutputPairs(inner_fusion);
3283+
alias_info_.GetInPlaceInputOutputPairs(inner_fusion);
32813284
std::vector<std::pair<HloOperandIndex, ShapeIndex>> inner_expected_pairs;
32823285
inner_expected_pairs.push_back({HloOperandIndex{1, {1}}, {1}});
32833286
EXPECT_EQ(inner_in_place_pairs, inner_expected_pairs);
32843287

3285-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3288+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
32863289
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32873290
expected_pairs.push_back({HloOperandIndex{1, {0}}, {2}});
32883291
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3319,7 +3322,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedLoopWithAliasingInDUSFusion) {
33193322
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
33203323
HloInstruction* fusion = module->entry_computation()->root_instruction();
33213324

3322-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3325+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
33233326
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
33243327
expected_pairs.push_back({HloOperandIndex{0, {0}}, {}});
33253328
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3372,7 +3375,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSLoopFusionWithCollective) {
33723375
)";
33733376
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
33743377
HloInstruction* fusion = module->entry_computation()->root_instruction();
3375-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3378+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
33763379
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
33773380
expected_pairs.push_back({HloOperandIndex{0, {}}, {1}});
33783381
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3423,7 +3426,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSOutputFusionWithCollective) {
34233426
)";
34243427
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
34253428
HloInstruction* fusion = module->entry_computation()->root_instruction();
3426-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3429+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
34273430
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34283431
expected_pairs.push_back({HloOperandIndex{0, {}}, {1}});
34293432
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3456,7 +3459,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSLoopFusionWithBitcast) {
34563459
)";
34573460
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
34583461
HloInstruction* fusion = module->entry_computation()->root_instruction();
3459-
auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
3462+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(fusion);
34603463
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34613464
// p1 should be aliased with fusion1
34623465
expected_pairs.push_back({HloOperandIndex{1, {}}, {}});
@@ -3485,7 +3488,7 @@ ENTRY AllToAll {
34853488
module->entry_computation()->root_instruction();
34863489

34873490
auto in_place_pairs =
3488-
HloDataflowAnalysis::GetInPlaceInputOutputPairs(ragged_all_to_all);
3491+
alias_info_.GetInPlaceInputOutputPairs(ragged_all_to_all);
34893492
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34903493
expected_pairs.push_back({HloOperandIndex{1, {}}, {}});
34913494
EXPECT_EQ(in_place_pairs, expected_pairs);
@@ -3607,8 +3610,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, nvshmem_ar) {
36073610
const HloInstruction* ar_start =
36083611
module->entry_computation()->root_instruction()->operand(0);
36093612

3610-
auto in_place_pairs =
3611-
HloDataflowAnalysis::GetInPlaceInputOutputPairs(ar_start);
3613+
auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs(ar_start);
36123614
std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
36133615
// For nvshmem allreduce, we expect no aliasing for input and output buffers
36143616
// therefore empty inplace pairs.

third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ void OptProvider::RegisterAllHardwareIndependentPasses() {
246246
RegisterPass<AsyncCollectiveCreator>(
247247
AsyncCollectiveCreator::CollectiveCreatorConfig());
248248
RegisterPass<BFloat16ConversionFolding>(
249-
/*bfloat16_support=*/bfloat16_support);
249+
/*bfloat16_support=*/bfloat16_support, alias_info_.get());
250250
RegisterPass<BFloat16MixedPrecisionRemoval>();
251251
RegisterPass<BFloat16Propagation>(/*bfloat16_support=*/bfloat16_support,
252252
alias_info_.get());

third_party/xla/xla/hlo/transforms/simplifiers/BUILD

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ cc_library(
158158
"//xla:shape_util",
159159
"//xla:util",
160160
"//xla:xla_data_proto_cc",
161-
"//xla/hlo/analysis:hlo_dataflow_analysis",
161+
"//xla/hlo/analysis:alias_info",
162162
"//xla/hlo/ir:hlo",
163163
"//xla/hlo/pass:hlo_pass",
164164
"//xla/service:float_support",
@@ -178,6 +178,7 @@ xla_cc_test(
178178
":bfloat16_conversion_folding",
179179
"//xla:shape_util",
180180
"//xla:xla_data_proto_cc",
181+
"//xla/hlo/analysis:alias_info",
181182
"//xla/hlo/ir:hlo",
182183
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
183184
"//xla/hlo/testlib:test_helpers",

third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc

Lines changed: 25 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ limitations under the License.
2323
#include "absl/status/status.h"
2424
#include "absl/status/statusor.h"
2525
#include "absl/strings/string_view.h"
26-
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
26+
#include "xla/hlo/analysis/alias_info.h"
2727
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
2828
#include "xla/hlo/ir/hlo_computation.h"
2929
#include "xla/hlo/ir/hlo_instruction.h"
@@ -40,9 +40,11 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
4040
public:
4141
explicit BFloat16ConversionFoldingVisitor(
4242
HloComputation* computation, const FloatSupport* bfloat16_support,
43+
const AliasInfo* alias_info,
4344
BFloat16ConversionFolding* bfloat16_conversion_folding)
4445
: computation_(computation),
4546
bfloat16_support_(bfloat16_support),
47+
alias_info_(alias_info),
4648
bfloat16_conversion_folding_(bfloat16_conversion_folding) {}
4749

4850
absl::Status DefaultAction(HloInstruction* hlo) override;
@@ -52,9 +54,10 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
5254

5355
static bool Run(HloComputation* computation,
5456
const FloatSupport* bfloat16_support,
57+
const AliasInfo* alias_info,
5558
BFloat16ConversionFolding* bfloat16_conversion_folding) {
56-
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support,
57-
bfloat16_conversion_folding);
59+
BFloat16ConversionFoldingVisitor visitor(
60+
computation, bfloat16_support, alias_info, bfloat16_conversion_folding);
5861
CHECK_OK(computation->Accept(&visitor));
5962
return visitor.changed_;
6063
}
@@ -77,6 +80,7 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
7780

7881
HloComputation* computation_;
7982
const FloatSupport* bfloat16_support_;
83+
const AliasInfo* alias_info_;
8084
BFloat16ConversionFolding* bfloat16_conversion_folding_;
8185
bool changed_ = false;
8286
};
@@ -172,22 +176,22 @@ absl::Status BFloat16ConversionFoldingVisitor::DefaultAction(
172176
// Do not fold BF16 conversions for instructions related to tuples, entry and
173177
// exit of a computation, fusion, convert, side-effecting instructions,
174178
// in-place operations and control flow.
175-
if (hlo->opcode() == HloOpcode::kTuple || //
176-
hlo->opcode() == HloOpcode::kGetTupleElement || //
177-
hlo->opcode() == HloOpcode::kConstant || //
178-
hlo->opcode() == HloOpcode::kParameter || //
179-
hlo->opcode() == HloOpcode::kFusion || //
180-
hlo->opcode() == HloOpcode::kBitcast || //
181-
hlo->opcode() == HloOpcode::kBitcastConvert || //
182-
hlo->opcode() == HloOpcode::kConvert || //
183-
hlo->opcode() == HloOpcode::kCall || //
184-
hlo->opcode() == HloOpcode::kCustomCall || //
185-
hlo->opcode() == HloOpcode::kWhile || //
186-
hlo->opcode() == HloOpcode::kConditional || //
187-
hlo->opcode() == HloOpcode::kAsyncStart || //
188-
hlo->opcode() == HloOpcode::kAsyncDone || //
189-
hlo->opcode() == HloOpcode::kOptimizationBarrier || //
190-
!HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo).empty() || //
179+
if (hlo->opcode() == HloOpcode::kTuple || //
180+
hlo->opcode() == HloOpcode::kGetTupleElement || //
181+
hlo->opcode() == HloOpcode::kConstant || //
182+
hlo->opcode() == HloOpcode::kParameter || //
183+
hlo->opcode() == HloOpcode::kFusion || //
184+
hlo->opcode() == HloOpcode::kBitcast || //
185+
hlo->opcode() == HloOpcode::kBitcastConvert || //
186+
hlo->opcode() == HloOpcode::kConvert || //
187+
hlo->opcode() == HloOpcode::kCall || //
188+
hlo->opcode() == HloOpcode::kCustomCall || //
189+
hlo->opcode() == HloOpcode::kWhile || //
190+
hlo->opcode() == HloOpcode::kConditional || //
191+
hlo->opcode() == HloOpcode::kAsyncStart || //
192+
hlo->opcode() == HloOpcode::kAsyncDone || //
193+
hlo->opcode() == HloOpcode::kOptimizationBarrier || //
194+
!alias_info_->GetInPlaceInputOutputPairs(hlo).empty() || //
191195
hlo->HasSideEffectNoRecurse()) {
192196
return absl::OkStatus();
193197
}
@@ -275,7 +279,8 @@ absl::StatusOr<bool> BFloat16ConversionFolding::RunImpl(
275279
module->ToString());
276280
bool changed = false;
277281
for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
278-
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_, this)) {
282+
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_,
283+
alias_info_, this)) {
279284
changed = true;
280285
}
281286
}

third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "absl/log/check.h"
2121
#include "absl/status/statusor.h"
2222
#include "absl/strings/string_view.h"
23+
#include "xla/hlo/analysis/alias_info.h"
2324
#include "xla/hlo/ir/hlo_module.h"
2425
#include "xla/hlo/pass/hlo_pass_interface.h"
2526
#include "xla/service/float_support.h"
@@ -38,8 +39,9 @@ namespace xla {
3839
// changed made by this pass.
3940
class BFloat16ConversionFolding : public HloModulePass {
4041
public:
41-
explicit BFloat16ConversionFolding(const FloatSupport* bfloat16_support)
42-
: bfloat16_support_(bfloat16_support) {
42+
BFloat16ConversionFolding(const FloatSupport* bfloat16_support,
43+
const AliasInfo* alias_info)
44+
: bfloat16_support_(bfloat16_support), alias_info_(alias_info) {
4345
DCHECK(bfloat16_support->LowPrecisionType() == BF16);
4446
}
4547

@@ -55,6 +57,7 @@ class BFloat16ConversionFolding : public HloModulePass {
5557

5658
private:
5759
const FloatSupport* bfloat16_support_;
60+
const AliasInfo* alias_info_;
5861
};
5962

6063
} // namespace xla

third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding_test.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <optional>
2020

2121
#include "absl/status/statusor.h"
22+
#include "xla/hlo/analysis/alias_info.h"
2223
#include "xla/hlo/ir/hlo_computation.h"
2324
#include "xla/hlo/ir/hlo_instruction.h"
2425
#include "xla/hlo/ir/hlo_module.h"
@@ -80,11 +81,12 @@ class BFloat16ConversionFoldingTest : public HloHardwareIndependentTestBase {
8081

8182
bool FoldConversions(HloModule* module) {
8283
TestBFloat16Support bfloat16_support_;
83-
BFloat16ConversionFolding fold(&bfloat16_support_);
84+
BFloat16ConversionFolding fold(&bfloat16_support_, &alias_info_);
8485
absl::StatusOr<bool> result = fold.Run(module);
8586
EXPECT_IS_OK(result.status());
8687
return result.value();
8788
}
89+
AliasInfo alias_info_;
8890
};
8991

9092
TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {

third_party/xla/xla/service/memory_space_assignment/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -597,6 +597,7 @@ cc_library(
597597
"//xla:shape_util",
598598
"//xla:util",
599599
"//xla:xla_data_proto_cc",
600+
"//xla/hlo/analysis:alias_info",
600601
"//xla/hlo/analysis:hlo_alias_analysis",
601602
"//xla/hlo/analysis:hlo_dataflow_analysis",
602603
"//xla/hlo/analysis:hlo_operand_index",

0 commit comments

Comments
 (0)