@@ -437,11 +437,12 @@ void compiler::utils::Barrier::Run(llvm::ModuleAnalysisManager &mam) {
437437 bi_ = &mam.getResult <BuiltinInfoAnalysis>(module_);
438438 FindBarriers ();
439439
440+ kernel_id_map_[kBarrier_EndID ] = nullptr ;
441+
440442 if (barriers_.empty ()) {
441443 // If there are no barriers, we can use the original function as the
442444 // single barrier region.
443- barrier_graph.emplace_back ();
444- auto &node = barrier_graph.back ();
445+ auto &node = barrier_region_id_map_[kBarrier_FirstID ];
445446 node.entry = &func_.getEntryBlock ();
446447 node.id = kBarrier_FirstID ;
447448 node.successor_ids .push_back (kBarrier_EndID );
@@ -513,11 +514,9 @@ void compiler::utils::Barrier::FindBarriers() {
513514 if (callee != nullptr ) {
514515 const auto B = bi_->analyzeBuiltin (*callee);
515516 if (BuiltinInfo::isMuxBuiltinWithWGBarrierID (B.ID )) {
516- unsigned id = ~0u ;
517517 auto *const id_param = call_inst->getOperand (0 );
518- if (auto *const id_param_c = dyn_cast<ConstantInt>(id_param)) {
519- id = id_param_c->getZExtValue ();
520- }
518+ auto *const id_param_c = cast<ConstantInt>(id_param);
519+ const auto id = id_param_c->getZExtValue ();
521520 orderedBarriers.emplace_back (id, call_inst);
522521 }
523522 }
@@ -548,13 +547,15 @@ void compiler::utils::Barrier::SplitBlockwithBarrier() {
548547 exit_stub = MakeStubFunction (" __barrier_exit" , module_, stub_cc);
549548 }
550549
551- barrier_graph.emplace_back ();
552- auto &node = barrier_graph.back ();
550+ auto &node = barrier_region_id_map_[kBarrier_FirstID ];
553551 node.entry = &func_.getEntryBlock ();
554552 node.id = kBarrier_FirstID ;
555553
556- unsigned barrier_id = kBarrier_StartNewID ;
557554 for (CallInst *split_point : barriers_) {
555+ // ID identifying which barrier invoked stub used as argument to call.
556+ auto *id = cast<ConstantInt>(split_point->getOperand (0 ));
557+ const auto barrier_id = kBarrier_StartNewID + id->getZExtValue ();
558+
558559 if (is_debug_) {
559560 assert (entry_stub != nullptr ); // Guaranteed as is_debug_ is const.
560561 assert (exit_stub != nullptr ); // Guaranteed as is_debug_ is const.
@@ -564,10 +565,6 @@ void compiler::utils::Barrier::SplitBlockwithBarrier() {
564565 // them at a point where live variables have already been loaded. This
565566 // info won't be available till later.
566567
567- // ID identifying which barrier invoked stub used as argument to call.
568- // This number monotonically increases from 0 for each barrier.
569- auto id = ConstantInt::get (Type::getInt32Ty (module_.getContext ()),
570- barrier_id - kBarrier_StartNewID );
571568 // Call invoking entry stub
572569 auto entry_caller = CallInst::Create (entry_stub, id);
573570 entry_caller->setDebugLoc (split_point->getDebugLoc ());
@@ -583,10 +580,9 @@ void compiler::utils::Barrier::SplitBlockwithBarrier() {
583580 std::make_pair (entry_caller, exit_caller);
584581 }
585582
586- barrier_graph.emplace_back ();
587- auto &node = barrier_graph.back ();
583+ auto &node = barrier_region_id_map_[barrier_id];
588584 node.barrier_inst = split_point;
589- node.id = barrier_id++ ;
585+ node.id = barrier_id;
590586 node.schedule = getBarrierSchedule (*split_point);
591587
592588 // Our scan implementation requires a linear work-item ordering, to loop
@@ -603,7 +599,7 @@ void compiler::utils::Barrier::SplitBlockwithBarrier() {
603599 // We have to gather the basic block data after splitting, because we
604600 // might not be processing barriers in program order, and things can get
605601 // awfully confused.
606- for (auto &node : barrier_graph ) {
602+ for (auto &[i, node] : barrier_region_id_map_ ) {
607603 if (node.barrier_inst ) {
608604 auto *const bb = node.barrier_inst ->getParent ();
609605 barrier_id_map_[bb] = node.id ;
@@ -770,7 +766,7 @@ void compiler::utils::Barrier::FindLiveVariables() {
770766 }
771767 }
772768
773- for (auto &region : barrier_graph ) {
769+ for (auto &[i, region] : barrier_region_id_map_ ) {
774770 GatherBarrierRegionBlocks (region);
775771 GatherBarrierRegionUses (region, func_args);
776772 whole_live_variables_set_.set_union (region.uses_int );
@@ -1150,9 +1146,9 @@ Function *compiler::utils::Barrier::GenerateNewKernel(BarrierRegion &region) {
11501146 } else if (ReturnInst *ret =
11511147 dyn_cast<ReturnInst>(cloned_bb->getTerminator ())) {
11521148 // Change return instruction with end barrier number.
1153- ConstantInt *cst_zero =
1149+ ConstantInt *cst_endid =
11541150 ConstantInt::get (Type::getInt32Ty (context), kBarrier_EndID );
1155- ReturnInst *new_ret = ReturnInst::Create (context, cst_zero );
1151+ ReturnInst *new_ret = ReturnInst::Create (context, cst_endid );
11561152 new_ret->insertBefore (ret->getIterator ());
11571153 ret->replaceAllUsesWith (new_ret);
11581154 ret->eraseFromParent ();
@@ -1450,7 +1446,7 @@ BasicBlock *compiler::utils::Barrier::CloneBasicBlock(
14501446void compiler::utils::Barrier::SeperateKernelWithBarrier () {
14511447 if (barriers_.empty ()) return ;
14521448
1453- for (auto &region : barrier_graph ) {
1449+ for (auto &[i, region] : barrier_region_id_map_ ) {
14541450 kernel_id_map_[region.id ] = GenerateNewKernel (region);
14551451 }
14561452
@@ -1467,15 +1463,10 @@ void compiler::utils::Barrier::SeperateKernelWithBarrier() {
14671463
14681464 LLVM_DEBUG ({
14691465 for (const auto &Kid : kernel_id_map_) {
1470- dbgs () << " 1. kernel_id[" << Kid.first << " ] = " << Kid.second ->getName ()
1466+ dbgs () << " kernel_id[" << Kid.first << " ] = " << Kid.second ->getName ()
14711467 << " \n " ;
14721468 }
14731469
1474- for (unsigned I = kBarrier_FirstID ;
1475- I < kernel_id_map_.size () + kBarrier_FirstID ; I++) {
1476- dbgs () << " 2. kernel_id[" << I << " ] = " << kernel_id_map_[I]->getName ()
1477- << " \n " ;
1478- }
14791470 dbgs () << " \n\n " << module_ << " \n\n " ;
14801471 });
14811472}
0 commit comments