Skip to content

Commit fc5acc4

Browse files
Add big/small constraint seapration to PT2, accelerate case when alpha string does not connect to a configuration
1 parent d868a9e commit fc5acc4

1 file changed

Lines changed: 159 additions & 7 deletions

File tree

include/macis/asci/pt2.hpp

Lines changed: 159 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ double asci_pt2_constraint(ASCISettings asci_settings,
9696
auto uniq_alpha = get_unique_alpha(cdets_begin, cdets_end);
9797
const size_t nuniq_alpha = uniq_alpha.size();
9898
logger->info(" * NUNIQ_ALPHA = {}", nuniq_alpha);
99+
std::vector<size_t> uniq_alpha_ioff(nuniq_alpha);
100+
std::transform_exclusive_scan(uniq_alpha.begin(), uniq_alpha.end(),
101+
uniq_alpha_ioff.begin(), 0ul, std::plus<size_t>(),
102+
[](const auto& p){ return p.second; });
99103

100104
using unique_alpha_data = std::vector<beta_coeff_data>;
101105
std::vector<unique_alpha_data> uad(nuniq_alpha);
@@ -172,12 +176,154 @@ double asci_pt2_constraint(ASCISettings asci_settings,
172176
size_t NPT2 = 0;
173177

174178
const size_t ncon_total = constraints.size();
179+
const size_t ncon_big = 250;
180+
const size_t ncon_small = ncon_total - ncon_big;
175181

176182
// Global atomic task-id counter
177-
global_atomic<size_t> nxtval(comm);
183+
global_atomic<size_t> nxtval_big(comm, 0);
184+
global_atomic<size_t> nxtval_small(comm, ncon_big);
178185
const double h_el_tol = asci_settings.pt2_tol;
179186

180187
auto pt2_st = clock_type::now();
188+
// Assign each "big" constraint to an MPI rank, thread over contributions
189+
{
190+
size_t ic = 0;
191+
while(ic < ncon_big) {
192+
// Atomically get the next task ID and increment for other
193+
// MPI ranks
194+
ic = nxtval_big.fetch_and_add(1);
195+
if(ic >= ncon_big) continue;
196+
printf("[pt2_big rank %4d] %10lu / %10lu\n", world_rank, ic, ncon_total);
197+
const auto& con = constraints[ic].first;
198+
199+
asci_contrib_container<wfn_t<N>> asci_pairs_con;
200+
#pragma omp parallel
201+
{
202+
asci_contrib_container<wfn_t<N>> asci_pairs;
203+
#pragma omp for schedule(dynamic)
204+
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
205+
const size_t old_pair_size = asci_pairs.size();
206+
const auto& alpha_det = uniq_alpha[i_alpha].first;
207+
const auto ncon_alpha = constraint_histogram(alpha_det,1,1,con);
208+
if(!ncon_alpha) continue;
209+
const auto occ_alpha = bits_to_indices(alpha_det);
210+
const bool alpha_satisfies_con = satisfies_constraint(alpha_det, con);
211+
212+
const auto& bcd = uad[i_alpha];
213+
const size_t nbeta = bcd.size();
214+
for(size_t j_beta = 0; j_beta < nbeta; ++j_beta) {
215+
const size_t iw = uniq_alpha_ioff[i_alpha] + j_beta;
216+
const auto w = *(cdets_begin + iw);
217+
const auto c = C[iw];
218+
const auto& beta_det = bcd[j_beta].beta_string;
219+
const auto h_diag = bcd[j_beta].h_diag;
220+
221+
// TODO: These copies are slow
222+
const auto& occ_beta_8 = bcd[j_beta].occ_beta;
223+
const auto& vir_beta_8 = bcd[j_beta].vir_beta;
224+
std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
225+
std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
226+
std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
227+
228+
std::vector<double> orb_ens_alpha, orb_ens_beta;
229+
if(asci_settings.pt2_precompute_eps) {
230+
orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
231+
orb_ens_beta = bcd[j_beta].orb_ens_beta;
232+
} else {
233+
orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
234+
orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
235+
}
236+
237+
// AA excitations
238+
generate_constraint_singles_contributions_ss(
239+
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), T_pq,
240+
norb, G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI,
241+
ham_gen, asci_pairs);
242+
243+
// AAAA excitations
244+
generate_constraint_doubles_contributions_ss(
245+
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), G_pqrs,
246+
norb, h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
247+
248+
// AABB excitations
249+
generate_constraint_doubles_contributions_os(
250+
c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data(),
251+
orb_ens_beta.data(), V_pqrs, norb, h_el_tol, h_diag, E_ASCI,
252+
ham_gen, asci_pairs);
253+
254+
if(alpha_satisfies_con) {
255+
// BB excitations
256+
append_singles_asci_contributions<Spin::Beta>(
257+
c, w, beta_det, occ_beta, vir_beta, occ_alpha,
258+
orb_ens_beta.data(), T_pq, norb, G_red, norb, V_red, norb,
259+
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
260+
261+
// BBBB excitations
262+
append_ss_doubles_asci_contributions<Spin::Beta>(
263+
c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
264+
orb_ens_beta.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
265+
ham_gen, asci_pairs);
266+
267+
// No excitation (push inf to remove from list)
268+
asci_pairs.push_back(
269+
{w, std::numeric_limits<double>::infinity(), 1.0});
270+
}
271+
}
272+
#if 0
273+
if(asci_settings.pt2_prune and asci_pairs.size() > asci_settings.pt2_reserve_count and asci_pairs.size() != old_pair_size) {
274+
// Cleanup
275+
auto uit = stable_sort_and_accumulate_asci_pairs(asci_pairs.begin(),
276+
asci_pairs.end());
277+
asci_pairs.erase(uit, asci_pairs.end());
278+
//uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
279+
//asci_pairs.erase(uit, asci_pairs.end());
280+
printf("[pt2_prune rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
281+
omp_get_thread_num(), ic, ncon_total, i_alpha,
282+
nuniq_alpha, asci_pairs.size());
283+
}
284+
#endif
285+
286+
} // Unique Alpha Loop
287+
288+
// S&A Thread local pairs
289+
sort_and_accumulate_asci_pairs(asci_pairs);
290+
291+
292+
// Insert
293+
#pragma omp critical
294+
{
295+
if(asci_pairs_con.size()) {
296+
asci_pairs_con.reserve(asci_pairs.size() + asci_pairs_con.size());
297+
asci_pairs_con.insert(asci_pairs_con.end(), asci_pairs.begin(), asci_pairs.end());
298+
} else {
299+
asci_pairs_con = std::move(asci_pairs);
300+
}
301+
}
302+
303+
} // OpenMP
304+
305+
double EPT2_local = 0.0;
306+
size_t NPT2_local = 0;
307+
// Local S&A for each quad + update EPT2
308+
{
309+
auto uit = sort_and_accumulate_asci_pairs(asci_pairs_con.begin(),
310+
asci_pairs_con.end());
311+
for(auto it = asci_pairs_con.begin(); it != uit; ++it) {
312+
if(!std::isinf(it->c_times_matel)) {
313+
EPT2_local += it->pt2();
314+
NPT2_local++;
315+
}
316+
}
317+
asci_pairs_con.clear();
318+
}
319+
320+
EPT2 += EPT2_local;
321+
NPT2 += NPT2_local;
322+
} // Constraint "loop"
323+
} // "Big constraints"
324+
325+
326+
// Parallelize over both MPI + threads for "small" constraints
181327
#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
182328
{
183329
// Process ASCI pair contributions for each constraint
@@ -188,24 +334,27 @@ double asci_pt2_constraint(ASCISettings asci_settings,
188334
// Atomically get the next task ID and increment for other
189335
// MPI ranks and threads
190336
size_t ntake = ic < 1000 ? 1 : 10;
191-
ic = nxtval.fetch_and_add(ntake);
337+
ic = nxtval_small.fetch_and_add(ntake);
192338

193339
// Loop over assigned tasks
194340
const size_t c_end = std::min(ncon_total, ic + ntake);
195341
for(; ic < c_end; ++ic) {
196342
const auto& con = constraints[ic].first;
197-
printf("[rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
343+
printf("[pt2_small rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
198344
omp_get_thread_num(), ic, ncon_total);
199345

200-
for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
346+
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
201347
const size_t old_pair_size = asci_pairs.size();
202348
const auto& alpha_det = uniq_alpha[i_alpha].first;
349+
const auto ncon_alpha = constraint_histogram(alpha_det,1,1,con);
350+
if(!ncon_alpha) continue;
203351
const auto occ_alpha = bits_to_indices(alpha_det);
204352
const bool alpha_satisfies_con = satisfies_constraint(alpha_det, con);
205353

206354
const auto& bcd = uad[i_alpha];
207355
const size_t nbeta = bcd.size();
208-
for(size_t j_beta = 0; j_beta < nbeta; ++j_beta, ++iw) {
356+
for(size_t j_beta = 0; j_beta < nbeta; ++j_beta) {
357+
const size_t iw = uniq_alpha_ioff[i_alpha] + j_beta;
209358
const auto w = *(cdets_begin + iw);
210359
const auto c = C[iw];
211360
const auto& beta_det = bcd[j_beta].beta_string;
@@ -267,8 +416,8 @@ double asci_pt2_constraint(ASCISettings asci_settings,
267416
auto uit = stable_sort_and_accumulate_asci_pairs(asci_pairs.begin(),
268417
asci_pairs.end());
269418
asci_pairs.erase(uit, asci_pairs.end());
270-
uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
271-
asci_pairs.erase(uit, asci_pairs.end());
419+
//uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
420+
//asci_pairs.erase(uit, asci_pairs.end());
272421
printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
273422
omp_get_thread_num(), ic, ncon_total, i_alpha,
274423
nuniq_alpha, asci_pairs.size());
@@ -289,6 +438,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
289438
}
290439
}
291440
asci_pairs.clear();
441+
// Deallocate
442+
if(asci_pairs.capacity() > asci_settings.pt2_reserve_count)
443+
asci_contrib_container<wfn_t<N>>().swap(asci_pairs);
292444
}
293445

294446
EPT2 += EPT2_local;

0 commit comments

Comments
 (0)