88
99#pragma once
1010#include < deque>
11+ #include < fstream>
12+ #include < sstream>
1113#include < macis/asci/determinant_contributions.hpp>
1214#include < macis/asci/determinant_sort.hpp>
1315
1416namespace macis {
1517
1618template <size_t N>
17- double asci_pt2_constraint (wavefunction_iterator_t <N> cdets_begin,
19+ double asci_pt2_constraint (ASCISettings asci_settings,
20+ wavefunction_iterator_t <N> cdets_begin,
1821 wavefunction_iterator_t <N> cdets_end,
1922 const double E_ASCI , const std::vector<double >& C,
2023 size_t norb, const double * T_pq, const double * G_red,
@@ -39,26 +42,34 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
3942 : spdlog::stdout_color_mt (" asci_pt2" );
4043
4144 const size_t ncdets = std::distance (cdets_begin, cdets_end);
45+ logger->info (" [ASCI PT2 Settings]" );
46+ logger->info (" * NDETS = {}" , ncdets);
47+ logger->info (" * PT2_TOL = {}" , asci_settings.pt2_tol );
48+ logger->info (" * PT2_RESERVE_COUNT = {}" , asci_settings.pt2_reserve_count );
49+ logger->info (" * PT2_CONSTRAINT_LVL = {}" , asci_settings.pt2_constraint_level );
50+ logger->info (" * PT2_PRUNE = {}" , asci_settings.pt2_prune );
51+ logger->info (" * PT2_PRECMP_EPS = {}" , asci_settings.pt2_precompute_eps );
52+ logger->info (" " );
4253
4354 // For each unique alpha, create a list of beta string and store metadata
4455 struct beta_coeff_data {
4556 spin_wfn_type beta_string;
46- std::vector<uint32_t > occ_beta;
47- std::vector<uint32_t > vir_beta;
57+ std::vector<uint8_t > occ_beta;
58+ std::vector<uint8_t > vir_beta;
4859 std::vector<double > orb_ens_alpha;
4960 std::vector<double > orb_ens_beta;
5061 double coeff;
5162 double h_diag;
5263
5364 size_t mem () const {
5465 return sizeof (spin_wfn_type) +
55- (occ_beta.capacity () + vir_beta.capacity ()) * sizeof (uint32_t ) +
66+ (occ_beta.capacity () + vir_beta.capacity ()) * sizeof (uint8_t ) +
5667 (2 + orb_ens_alpha.capacity () + orb_ens_beta.capacity ()) * sizeof (double );
5768 }
5869
5970 beta_coeff_data (double c, size_t norb,
6071 const std::vector<uint32_t >& occ_alpha, wfn_t <N> w,
61- const HamiltonianGenerator<wfn_t <N>>& ham_gen) {
72+ const HamiltonianGenerator<wfn_t <N>>& ham_gen, bool pce ) {
6273 coeff = c;
6374
6475 beta_string = wfn_traits::beta_string (w);
@@ -67,16 +78,24 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
6778 h_diag = ham_gen.matrix_element (w, w);
6879
6980 // Compute occ/vir for beta string
70- spin_wfn_traits::state_to_occ_vir (norb, beta_string, occ_beta, vir_beta);
81+ std::vector<uint32_t > o_32, v_32;
82+ spin_wfn_traits::state_to_occ_vir (norb, beta_string, o_32, v_32);
83+ occ_beta.resize (o_32.size ());
84+ std::copy (o_32.begin (), o_32.end (), occ_beta.begin ());
85+ vir_beta.resize (v_32.size ());
86+ std::copy (v_32.begin (), v_32.end (), vir_beta.begin ());
7187
7288 // Precompute orbital energies
73- // orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
74- // orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
89+ if (pce) {
90+ orb_ens_alpha = ham_gen.single_orbital_ens (norb, occ_alpha, o_32);
91+ orb_ens_beta = ham_gen.single_orbital_ens (norb, o_32, occ_alpha);
92+ }
7593 }
7694 };
7795
7896 auto uniq_alpha = get_unique_alpha (cdets_begin, cdets_end);
7997 const size_t nuniq_alpha = uniq_alpha.size ();
98+ logger->info (" * NUNIQ_ALPHA = {}" , nuniq_alpha);
8099
81100 using unique_alpha_data = std::vector<beta_coeff_data>;
82101 std::vector<unique_alpha_data> uad (nuniq_alpha);
@@ -89,7 +108,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
89108 uad[i].reserve (nbeta);
90109 for (auto j = 0 ; j < nbeta; ++j, ++iw) {
91110 const auto & w = *(cdets_begin + iw);
92- uad[i].emplace_back (C[iw], norb, occ_alpha, w, ham_gen);
111+ uad[i].emplace_back (C[iw], norb, occ_alpha, w, ham_gen,asci_settings. pt2_precompute_eps );
93112 }
94113 }
95114
@@ -105,7 +124,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
105124 }
106125 }
107126 printf (" MEM REQ ALPH = %.2e\n " , mem_alpha / gib);
108- printf (" MEM REQ CONT = %.2e\n " , 70000000 * sizeof (asci_contrib<wfn_t <N>>)/ 1024 ./ 1024 ./ 1024 );
127+ printf (" MEM REQ CONT = %.2e\n " , asci_settings. pt2_reserve_count * sizeof (asci_contrib<wfn_t <N>>)/ gib );
109128 }
110129 MPI_Barrier (comm);
111130
@@ -125,11 +144,29 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
125144 // auto constraints = dist_constraint_general<wfn_t<N>>(
126145 // 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
127146 auto constraints = gen_constraints_general<wfn_t <N>>(
128- 10 , norb, n_sing_beta, n_doub_beta, uniq_alpha,
129- world_size * omp_get_max_threads ());
147+ asci_settings. pt2_constraint_level , norb, n_sing_beta,
148+ n_doub_beta, uniq_alpha, world_size * omp_get_max_threads (), 0 );
130149 auto gen_c_en = clock_type::now ();
131150 duration_type gen_c_dur = gen_c_en - gen_c_st;
132151 logger->info (" * GEN_DUR = {:.2e} ms" , gen_c_dur.count ());
152+ // if(!world_rank) {
153+ // std::ofstream c_file("constraint_work.txt");
154+ // std::stringstream ss;
155+ // for(auto [c,s] : constraints) {
156+ // ss << c.C() << " " << s << std::endl;
157+ // }
158+ // auto str = ss.str();
159+ // c_file.write(str.c_str(), str.size());
160+ // }
161+ // if(!world_rank) {
162+ // std::ofstream c_file("unique_alpha.txt");
163+ // std::stringstream ss;
164+ // for(size_t i = 0; i < nuniq_alpha; ++i) {
165+ // ss << uniq_alpha[i].first << " " << uniq_alpha[i].second << std::endl;
166+ // }
167+ // auto str = ss.str();
168+ // c_file.write(str.c_str(), str.size());
169+ // }
133170
134171 double EPT2 = 0.0 ;
135172 size_t NPT2 = 0 ;
@@ -138,14 +175,14 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
138175
139176 // Global atomic task-id counter
140177 global_atomic<size_t > nxtval (comm);
141- const double h_el_tol = 1e-6 ;
178+ const double h_el_tol = asci_settings. pt2_tol ;
142179
143180 auto pt2_st = clock_type::now ();
144181#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
145182 {
146183 // Process ASCI pair contributions for each constraint
147184 asci_contrib_container<wfn_t <N>> asci_pairs;
148- asci_pairs.reserve (70000000ul );
185+ // asci_pairs.reserve(asci_settings.pt2_reserve_count );
149186 size_t ic = 0 ;
150187 while (ic < ncon_total) {
151188 // Atomically get the next task ID and increment for other
@@ -173,12 +210,22 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
173210 const auto c = C[iw];
174211 const auto & beta_det = bcd[j_beta].beta_string ;
175212 const auto h_diag = bcd[j_beta].h_diag ;
176- const auto & occ_beta = bcd[j_beta].occ_beta ;
177- const auto & vir_beta = bcd[j_beta].vir_beta ;
178- // const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
179- // const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
180- auto orb_ens_alpha = ham_gen.single_orbital_ens (norb, occ_alpha, occ_beta);
181- auto orb_ens_beta = ham_gen.single_orbital_ens (norb, occ_beta, occ_alpha);
213+
214+ // TODO: These copies are slow
215+ const auto & occ_beta_8 = bcd[j_beta].occ_beta ;
216+ const auto & vir_beta_8 = bcd[j_beta].vir_beta ;
217+ std::vector<uint32_t > occ_beta (occ_beta_8.size ()), vir_beta (vir_beta_8.size ());
218+ std::copy (occ_beta_8.begin (), occ_beta_8.end (), occ_beta.begin ());
219+ std::copy (vir_beta_8.begin (), vir_beta_8.end (), vir_beta.begin ());
220+
221+ std::vector<double > orb_ens_alpha, orb_ens_beta;
222+ if (asci_settings.pt2_precompute_eps ) {
223+ orb_ens_alpha = bcd[j_beta].orb_ens_alpha ;
224+ orb_ens_beta = bcd[j_beta].orb_ens_beta ;
225+ } else {
226+ orb_ens_alpha = ham_gen.single_orbital_ens (norb, occ_alpha, occ_beta);
227+ orb_ens_beta = ham_gen.single_orbital_ens (norb, occ_beta, occ_alpha);
228+ }
182229
183230 // AA excitations
184231 generate_constraint_singles_contributions_ss (
@@ -215,10 +262,12 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
215262 {w, std::numeric_limits<double >::infinity (), 1.0 });
216263 }
217264 }
218- if (asci_pairs.size () > 70000000 and asci_pairs.size () != old_pair_size) {
265+ if (asci_settings. pt2_prune and asci_pairs.size () > asci_settings. pt2_reserve_count and asci_pairs.size () != old_pair_size) {
219266 // Cleanup
220- auto uit = sort_and_accumulate_asci_pairs (asci_pairs.begin (),
267+ auto uit = stable_sort_and_accumulate_asci_pairs (asci_pairs.begin (),
221268 asci_pairs.end ());
269+ asci_pairs.erase (uit, asci_pairs.end ());
270+ uit = std::stable_partition (asci_pairs.begin (), asci_pairs.end (), [&](const auto & p){ return std::abs (p.pt2 ()) > h_el_tol; });
222271 asci_pairs.erase (uit, asci_pairs.end ());
223272 printf (" [rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n " , world_rank,
224273 omp_get_thread_num (), ic, ncon_total, i_alpha,
@@ -235,8 +284,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
235284 asci_pairs.end ());
236285 for (auto it = asci_pairs.begin (); it != uit; ++it) {
237286 if (!std::isinf (it->c_times_matel )) {
238- EPT2_local +=
239- (it->c_times_matel * it->c_times_matel ) / it->h_diag ;
287+ EPT2_local += it->pt2 ();
240288 NPT2_local++;
241289 }
242290 }
0 commit comments