@@ -96,6 +96,10 @@ double asci_pt2_constraint(ASCISettings asci_settings,
9696 auto uniq_alpha = get_unique_alpha (cdets_begin, cdets_end);
9797 const size_t nuniq_alpha = uniq_alpha.size ();
9898 logger->info (" * NUNIQ_ALPHA = {}" , nuniq_alpha);
99+ std::vector<size_t > uniq_alpha_ioff (nuniq_alpha);
100+ std::transform_exclusive_scan (uniq_alpha.begin (), uniq_alpha.end (),
101+ uniq_alpha_ioff.begin (), 0ul , std::plus<size_t >(),
102+ [](const auto & p){ return p.second ; });
99103
100104 using unique_alpha_data = std::vector<beta_coeff_data>;
101105 std::vector<unique_alpha_data> uad (nuniq_alpha);
@@ -172,12 +176,154 @@ double asci_pt2_constraint(ASCISettings asci_settings,
172176 size_t NPT2 = 0 ;
173177
174178 const size_t ncon_total = constraints.size ();
179+ const size_t ncon_big = 250 ;
180+ const size_t ncon_small = ncon_total - ncon_big;
175181
176182 // Global atomic task-id counter
177- global_atomic<size_t > nxtval (comm);
183+ global_atomic<size_t > nxtval_big (comm, 0 );
184+ global_atomic<size_t > nxtval_small (comm, ncon_big);
178185 const double h_el_tol = asci_settings.pt2_tol ;
179186
180187 auto pt2_st = clock_type::now ();
188+ // Assign each "big" constraint to an MPI rank, thread over contributions
189+ {
190+ size_t ic = 0 ;
191+ while (ic < ncon_big) {
192+ // Atomically get the next task ID and increment for other
193+ // MPI ranks
194+ ic = nxtval_big.fetch_and_add (1 );
195+ if (ic >= ncon_big) continue ;
196+ printf (" [pt2_big rank %4d] %10lu / %10lu\n " , world_rank, ic, ncon_total);
197+ const auto & con = constraints[ic].first ;
198+
199+ asci_contrib_container<wfn_t <N>> asci_pairs_con;
200+ #pragma omp parallel
201+ {
202+ asci_contrib_container<wfn_t <N>> asci_pairs;
203+ #pragma omp for schedule(dynamic)
204+ for (size_t i_alpha = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
205+ const size_t old_pair_size = asci_pairs.size ();
206+ const auto & alpha_det = uniq_alpha[i_alpha].first ;
207+ const auto ncon_alpha = constraint_histogram (alpha_det,1 ,1 ,con);
208+ if (!ncon_alpha) continue ;
209+ const auto occ_alpha = bits_to_indices (alpha_det);
210+ const bool alpha_satisfies_con = satisfies_constraint (alpha_det, con);
211+
212+ const auto & bcd = uad[i_alpha];
213+ const size_t nbeta = bcd.size ();
214+ for (size_t j_beta = 0 ; j_beta < nbeta; ++j_beta) {
215+ const size_t iw = uniq_alpha_ioff[i_alpha] + j_beta;
216+ const auto w = *(cdets_begin + iw);
217+ const auto c = C[iw];
218+ const auto & beta_det = bcd[j_beta].beta_string ;
219+ const auto h_diag = bcd[j_beta].h_diag ;
220+
221+ // TODO: These copies are slow
222+ const auto & occ_beta_8 = bcd[j_beta].occ_beta ;
223+ const auto & vir_beta_8 = bcd[j_beta].vir_beta ;
224+ std::vector<uint32_t > occ_beta (occ_beta_8.size ()), vir_beta (vir_beta_8.size ());
225+ std::copy (occ_beta_8.begin (), occ_beta_8.end (), occ_beta.begin ());
226+ std::copy (vir_beta_8.begin (), vir_beta_8.end (), vir_beta.begin ());
227+
228+ std::vector<double > orb_ens_alpha, orb_ens_beta;
229+ if (asci_settings.pt2_precompute_eps ) {
230+ orb_ens_alpha = bcd[j_beta].orb_ens_alpha ;
231+ orb_ens_beta = bcd[j_beta].orb_ens_beta ;
232+ } else {
233+ orb_ens_alpha = ham_gen.single_orbital_ens (norb, occ_alpha, occ_beta);
234+ orb_ens_beta = ham_gen.single_orbital_ens (norb, occ_beta, occ_alpha);
235+ }
236+
237+ // AA excitations
238+ generate_constraint_singles_contributions_ss (
239+ c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), T_pq,
240+ norb, G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI ,
241+ ham_gen, asci_pairs);
242+
243+ // AAAA excitations
244+ generate_constraint_doubles_contributions_ss (
245+ c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), G_pqrs,
246+ norb, h_el_tol, h_diag, E_ASCI , ham_gen, asci_pairs);
247+
248+ // AABB excitations
249+ generate_constraint_doubles_contributions_os (
250+ c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data (),
251+ orb_ens_beta.data (), V_pqrs, norb, h_el_tol, h_diag, E_ASCI ,
252+ ham_gen, asci_pairs);
253+
254+ if (alpha_satisfies_con) {
255+ // BB excitations
256+ append_singles_asci_contributions<Spin::Beta>(
257+ c, w, beta_det, occ_beta, vir_beta, occ_alpha,
258+ orb_ens_beta.data (), T_pq, norb, G_red, norb, V_red, norb,
259+ h_el_tol, h_diag, E_ASCI , ham_gen, asci_pairs);
260+
261+ // BBBB excitations
262+ append_ss_doubles_asci_contributions<Spin::Beta>(
263+ c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
264+ orb_ens_beta.data (), G_pqrs, norb, h_el_tol, h_diag, E_ASCI ,
265+ ham_gen, asci_pairs);
266+
267+ // No excitation (push inf to remove from list)
268+ asci_pairs.push_back (
269+ {w, std::numeric_limits<double >::infinity (), 1.0 });
270+ }
271+ }
272+ #if 0
273+ if(asci_settings.pt2_prune and asci_pairs.size() > asci_settings.pt2_reserve_count and asci_pairs.size() != old_pair_size) {
274+ // Cleanup
275+ auto uit = stable_sort_and_accumulate_asci_pairs(asci_pairs.begin(),
276+ asci_pairs.end());
277+ asci_pairs.erase(uit, asci_pairs.end());
278+ //uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
279+ //asci_pairs.erase(uit, asci_pairs.end());
280+ printf("[pt2_prune rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
281+ omp_get_thread_num(), ic, ncon_total, i_alpha,
282+ nuniq_alpha, asci_pairs.size());
283+ }
284+ #endif
285+
286+ } // Unique Alpha Loop
287+
288+ // S&A Thread local pairs
289+ sort_and_accumulate_asci_pairs (asci_pairs);
290+
291+
292+ // Insert
293+ #pragma omp critical
294+ {
295+ if (asci_pairs_con.size ()) {
296+ asci_pairs_con.reserve (asci_pairs.size () + asci_pairs_con.size ());
297+ asci_pairs_con.insert (asci_pairs_con.end (), asci_pairs.begin (), asci_pairs.end ());
298+ } else {
299+ asci_pairs_con = std::move (asci_pairs);
300+ }
301+ }
302+
303+ } // OpenMP
304+
305+ double EPT2_local = 0.0 ;
306+ size_t NPT2_local = 0 ;
307+ // Local S&A for each quad + update EPT2
308+ {
309+ auto uit = sort_and_accumulate_asci_pairs (asci_pairs_con.begin (),
310+ asci_pairs_con.end ());
311+ for (auto it = asci_pairs_con.begin (); it != uit; ++it) {
312+ if (!std::isinf (it->c_times_matel )) {
313+ EPT2_local += it->pt2 ();
314+ NPT2_local++;
315+ }
316+ }
317+ asci_pairs_con.clear ();
318+ }
319+
320+ EPT2 += EPT2_local;
321+ NPT2 += NPT2_local;
322+ } // Constraint "loop"
323+ } // "Big constraints"
324+
325+
326+ // Parallelize over both MPI + threads for "small" constraints
181327#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
182328 {
183329 // Process ASCI pair contributions for each constraint
@@ -188,24 +334,27 @@ double asci_pt2_constraint(ASCISettings asci_settings,
188334 // Atomically get the next task ID and increment for other
189335 // MPI ranks and threads
190336 size_t ntake = ic < 1000 ? 1 : 10 ;
191- ic = nxtval .fetch_and_add (ntake);
337+ ic = nxtval_small .fetch_and_add (ntake);
192338
193339 // Loop over assigned tasks
194340 const size_t c_end = std::min (ncon_total, ic + ntake);
195341 for (; ic < c_end; ++ic) {
196342 const auto & con = constraints[ic].first ;
197- printf (" [rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
343+ printf (" [pt2_small rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
198344 omp_get_thread_num (), ic, ncon_total);
199345
200- for (size_t i_alpha = 0 , iw = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
346+ for (size_t i_alpha = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
201347 const size_t old_pair_size = asci_pairs.size ();
202348 const auto & alpha_det = uniq_alpha[i_alpha].first ;
349+ const auto ncon_alpha = constraint_histogram (alpha_det,1 ,1 ,con);
350+ if (!ncon_alpha) continue ;
203351 const auto occ_alpha = bits_to_indices (alpha_det);
204352 const bool alpha_satisfies_con = satisfies_constraint (alpha_det, con);
205353
206354 const auto & bcd = uad[i_alpha];
207355 const size_t nbeta = bcd.size ();
208- for (size_t j_beta = 0 ; j_beta < nbeta; ++j_beta, ++iw) {
356+ for (size_t j_beta = 0 ; j_beta < nbeta; ++j_beta) {
357+ const size_t iw = uniq_alpha_ioff[i_alpha] + j_beta;
209358 const auto w = *(cdets_begin + iw);
210359 const auto c = C[iw];
211360 const auto & beta_det = bcd[j_beta].beta_string ;
@@ -267,8 +416,8 @@ double asci_pt2_constraint(ASCISettings asci_settings,
267416 auto uit = stable_sort_and_accumulate_asci_pairs (asci_pairs.begin (),
268417 asci_pairs.end ());
269418 asci_pairs.erase (uit, asci_pairs.end ());
270- uit = std::stable_partition (asci_pairs.begin (), asci_pairs.end (), [&](const auto & p){ return std::abs (p.pt2 ()) > h_el_tol; });
271- asci_pairs.erase (uit, asci_pairs.end ());
419+ // uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
420+ // asci_pairs.erase(uit, asci_pairs.end());
272421 printf (" [rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n " , world_rank,
273422 omp_get_thread_num (), ic, ncon_total, i_alpha,
274423 nuniq_alpha, asci_pairs.size ());
@@ -289,6 +438,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
289438 }
290439 }
291440 asci_pairs.clear ();
441+ // Deallocate
442+ if (asci_pairs.capacity () > asci_settings.pt2_reserve_count )
443+ asci_contrib_container<wfn_t <N>>().swap (asci_pairs);
292444 }
293445
294446 EPT2 += EPT2_local;
0 commit comments