@@ -91,15 +91,6 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
9191 }
9292 }
9393
94- // if(world_rank == 0) {
95- // std::ofstream ofile("uniq_alpha.txt");
96- // for(auto [d, c] : uniq_alpha) {
97- // ofile << to_canonical_string(wfn_traits::from_spin(d,0)) << " " << c <<
98- // std::endl;
99- // }
100- // }
101-
102- // const auto n_occ_alpha = wfn_traits::count(uniq_alpha_wfn[0]);
10394 const auto n_occ_alpha = spin_wfn_traits::count (uniq_alpha[0 ].first );
10495 const auto n_vir_alpha = norb - n_occ_alpha;
10596 const auto n_sing_alpha = n_occ_alpha * n_vir_alpha;
@@ -121,56 +112,35 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
121112 duration_type gen_c_dur = gen_c_en - gen_c_st;
122113 logger->info (" * GEN_DUR = {:.2e} ms" , gen_c_dur.count ());
123114
124- size_t max_size = 100000000ul ;
125-
126115 double EPT2 = 0.0 ;
127116 size_t NPT2 = 0 ;
128- auto pt2_st = clock_type::now ();
129- std::deque<size_t > print_points (100 );
130- for (auto i = 0 ; i < 100 ; ++i) {
131- print_points[i] = constraints.size () * (i / 100 .);
132- }
133- // std::mutex print_barrier;
134117
135118 const size_t ncon_total = constraints.size ();
136- #if 0
137- MPI_Win window;
138- // MPI_Win_create( &window_count, sizeof(size_t), sizeof(size_t),
139- // MPI_INFO_NULL, comm, &window );
140- size_t* window_buffer;
141- MPI_Win_allocate(sizeof(size_t), sizeof(size_t), MPI_INFO_NULL, comm,
142- &window_buffer, &window);
143- if(window == MPI_WIN_NULL) throw std::runtime_error("Window failed");
144- MPI_Win_lock_all(MPI_MODE_NOCHECK, window);
145- #else
119+
120+ // Global atomic task-id counter
146121 global_atomic<size_t > nxtval (comm);
147- #endif
148- // Process ASCI pair contributions for each constraint
122+ const double h_el_tol = 1e-16 ;
123+
124+ auto pt2_st = clock_type::now ();
149125#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
150126 {
127+ // Process ASCI pair contributions for each constraint
151128 asci_contrib_container<wfn_t <N>> asci_pairs;
152- asci_pairs.reserve (max_size);
153- // #pragma omp for
154- // for(size_t ic = 0; ic < constraints.size(); ++ic)
129+ asci_pairs.reserve (100000000ul );
155130 size_t ic = 0 ;
156131 while (ic < ncon_total) {
132+
133+ // Atomically get the next task ID and increment for other
134+ // MPI ranks and threads
157135 size_t ntake = ic < 1000 ? 1 : 10 ;
158- // MPI_Fetch_and_op(&ntake, &ic, MPI_UINT64_T, 0, 0, MPI_SUM, window);
159- // MPI_Win_flush(0, window);
160136 ic = nxtval.fetch_and_add (ntake);
161137
162138 // Loop over assigned tasks
163139 const size_t c_end = std::min (ncon_total, ic + ntake);
164140 for (; ic < c_end; ++ic) {
165141 const auto & con = constraints[ic].first ;
166- // if(ic >= print_points.front()) {
167- // //std::lock_guard<std::mutex> lock(print_barrier);
168- // printf("[rank %d] %.1f done\n", world_rank,
169- // double(ic)/constraints.size()*100); print_points.pop_front();
170- // }
171- printf (" [rank %4d tid:%4d] %lu / %lu\n " , world_rank,
142+ printf (" [rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
172143 omp_get_thread_num (), ic, ncon_total);
173- const double h_el_tol = 1e-16 ;
174144
175145 for (size_t i_alpha = 0 , iw = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
176146 const auto & alpha_det = uniq_alpha[i_alpha].first ;
@@ -234,7 +204,6 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
234204 auto uit = sort_and_accumulate_asci_pairs (asci_pairs.begin (),
235205 asci_pairs.end ());
236206 for (auto it = asci_pairs.begin (); it != uit; ++it) {
237- // if(std::find(cdets_begin, cdets_end, it->state) == cdets_end)
238207 if (!std::isinf (it->c_times_matel )) {
239208 EPT2_local +=
240209 (it->c_times_matel * it->c_times_matel ) / it->h_diag ;
@@ -248,7 +217,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
248217 NPT2 += NPT2_local;
249218 } // Loc constraint loop
250219 } // Constraint Loop
251- }
220+ } // OpenMP
252221 auto pt2_en = clock_type::now ();
253222
254223 EPT2 = allreduce (EPT2 , MPI_SUM , comm);
@@ -266,8 +235,6 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
266235
267236 NPT2 = allreduce (NPT2 , MPI_SUM , comm);
268237 logger->info (" * NPT2 = {}" , NPT2 );
269- // MPI_Win_unlock_all(window);
270- // MPI_Win_free(&window);
271238
272239 return EPT2 ;
273240}
0 commit comments