@@ -3079,7 +3079,10 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
30793079 reverse, {}, call, {}, &alias_info_));
30803080}
30813081
3082- using GetInPlaceInputOutputPairsTest = HloHardwareIndependentTestBase;
3082+ class GetInPlaceInputOutputPairsTest : public HloHardwareIndependentTestBase {
3083+ protected:
3084+ AliasInfo alias_info_;
3085+ };
30833086
30843087TEST_F (GetInPlaceInputOutputPairsTest, DUS) {
30853088 const char * kModule = R"(
@@ -3095,7 +3098,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUS) {
30953098 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
30963099 HloInstruction* dus = module ->entry_computation ()->root_instruction ();
30973100
3098- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (dus);
3101+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (dus);
30993102 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31003103 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {}});
31013104 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3122,7 +3125,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSFusion) {
31223125 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
31233126 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
31243127
3125- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3128+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
31263129 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31273130 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {}});
31283131 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3150,7 +3153,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSFusionWithOutputOperandAliasing) {
31503153 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
31513154 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
31523155
3153- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3156+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
31543157 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
31553158 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {1 }}); // discovered
31563159 expected_pairs.push_back ({HloOperandIndex{1 , {}}, {0 }}); // annotated
@@ -3176,7 +3179,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusion) {
31763179 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
31773180 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
31783181
3179- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3182+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
31803183 EXPECT_THAT (in_place_pairs, IsEmpty ());
31813184}
31823185
@@ -3198,7 +3201,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusionWithOutputOperandAliasing) {
31983201 )" ;
31993202 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
32003203 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
3201- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3204+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
32023205
32033206 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32043207 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {}});
@@ -3233,7 +3236,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedDUSFusion) {
32333236 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
32343237 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
32353238
3236- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3239+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
32373240 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32383241 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {}});
32393242 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3277,12 +3280,12 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedMultiOutputDUSFusion) {
32773280 HloInstruction* inner_fusion = FindInstruction (module .get (), " inner_fusion" );
32783281
32793282 auto inner_in_place_pairs =
3280- HloDataflowAnalysis:: GetInPlaceInputOutputPairs (inner_fusion);
3283+ alias_info_. GetInPlaceInputOutputPairs (inner_fusion);
32813284 std::vector<std::pair<HloOperandIndex, ShapeIndex>> inner_expected_pairs;
32823285 inner_expected_pairs.push_back ({HloOperandIndex{1 , {1 }}, {1 }});
32833286 EXPECT_EQ (inner_in_place_pairs, inner_expected_pairs);
32843287
3285- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3288+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
32863289 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
32873290 expected_pairs.push_back ({HloOperandIndex{1 , {0 }}, {2 }});
32883291 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3319,7 +3322,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, NestedLoopWithAliasingInDUSFusion) {
33193322 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
33203323 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
33213324
3322- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3325+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
33233326 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
33243327 expected_pairs.push_back ({HloOperandIndex{0 , {0 }}, {}});
33253328 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3372,7 +3375,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSLoopFusionWithCollective) {
33723375 )" ;
33733376 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
33743377 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
3375- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3378+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
33763379 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
33773380 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {1 }});
33783381 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3423,7 +3426,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSOutputFusionWithCollective) {
34233426 )" ;
34243427 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
34253428 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
3426- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3429+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
34273430 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34283431 expected_pairs.push_back ({HloOperandIndex{0 , {}}, {1 }});
34293432 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3456,7 +3459,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSLoopFusionWithBitcast) {
34563459 )" ;
34573460 TF_ASSERT_OK_AND_ASSIGN (auto module , ParseAndReturnVerifiedModule (kModule ));
34583461 HloInstruction* fusion = module ->entry_computation ()->root_instruction ();
3459- auto in_place_pairs = HloDataflowAnalysis:: GetInPlaceInputOutputPairs (fusion);
3462+ auto in_place_pairs = alias_info_. GetInPlaceInputOutputPairs (fusion);
34603463 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34613464 // p1 should be aliased with fusion1
34623465 expected_pairs.push_back ({HloOperandIndex{1 , {}}, {}});
@@ -3485,7 +3488,7 @@ ENTRY AllToAll {
34853488 module ->entry_computation ()->root_instruction ();
34863489
34873490 auto in_place_pairs =
3488- HloDataflowAnalysis:: GetInPlaceInputOutputPairs (ragged_all_to_all);
3491+ alias_info_. GetInPlaceInputOutputPairs (ragged_all_to_all);
34893492 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
34903493 expected_pairs.push_back ({HloOperandIndex{1 , {}}, {}});
34913494 EXPECT_EQ (in_place_pairs, expected_pairs);
@@ -3607,8 +3610,7 @@ TEST_F(GetInPlaceInputOutputPairsTest, nvshmem_ar) {
36073610 const HloInstruction* ar_start =
36083611 module ->entry_computation ()->root_instruction ()->operand (0 );
36093612
3610- auto in_place_pairs =
3611- HloDataflowAnalysis::GetInPlaceInputOutputPairs (ar_start);
3613+ auto in_place_pairs = alias_info_.GetInPlaceInputOutputPairs (ar_start);
36123614 std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
36133615 // For nvshmem allreduce, we expect no aliasing for input and output buffers
36143616 // therefore empty inplace pairs.
0 commit comments