Skip to content

Commit d868a9e

Browse files
Expose additional PT2 options, refactor sort_and_accumulate, add coeffs into append_xyz_contributions
1 parent 17e9265 commit d868a9e

6 files changed

Lines changed: 155 additions & 46 deletions

File tree

include/macis/asci/determinant_contributions.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct asci_contrib {
2020
double h_diag;
2121

2222
auto rv() const { return c_times_matel / h_diag; }
23+
auto pt2() const { return rv() * c_times_matel; }
2324
};
2425

2526
template <typename WfnT>
@@ -48,7 +49,7 @@ void append_singles_asci_contributions(
4849
for(auto p : occ_othr) h_el += V_ov[p];
4950

5051
// Early Exit
51-
if(std::abs(h_el) < h_el_tol) continue;
52+
if(std::abs(coeff * h_el) < h_el_tol) continue;
5253

5354
// Calculate Excited Determinant
5455
auto ex_det = wfn_traits::template single_excitation_no_check<Sigma>(
@@ -97,7 +98,7 @@ void append_ss_doubles_asci_contributions(
9798
const auto jb = b + j * LDG;
9899
const auto G_aibj = G_ai[jb];
99100

100-
if(std::abs(G_aibj) < h_el_tol) continue;
101+
if(std::abs(coeff * G_aibj) < h_el_tol) continue;
101102

102103
#if 0
103104
// Calculate excited determinant string (spin)
@@ -165,7 +166,7 @@ void append_os_doubles_asci_contributions(
165166
const auto jb = b + j * LDV;
166167
const auto V_aibj = V_ai[jb * LDV2];
167168

168-
if(std::abs(V_aibj) < h_el_tol) continue;
169+
if(std::abs(coeff * V_aibj) < h_el_tol) continue;
169170

170171
double sign_beta = single_excitation_sign(state_beta, b, j);
171172
double sign = sign_alpha * sign_beta;

include/macis/asci/determinant_search.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ struct ASCISettings {
3939
double h_el_tol = 1e-8;
4040
double rv_prune_tol = 1e-8;
4141
size_t pair_size_max = 5e8;
42+
43+
double pt2_tol = 1e-16;
44+
size_t pt2_reserve_count = 70000000;
45+
bool pt2_prune = false;
46+
bool pt2_precompute_eps = false;
47+
4248
bool just_singles = false;
4349
size_t grow_factor = 8;
4450
size_t max_refine_iter = 6;
@@ -49,6 +55,7 @@ struct ASCISettings {
4955

5056
// bool dist_triplet_random = false;
5157
int constraint_level = 2; // Up To Quints
58+
int pt2_constraint_level = 5;
5259
};
5360

5461
template <size_t N>

include/macis/asci/determinant_sort.hpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#if __has_include(<boost/sort/pdqsort/pdqsort.hpp>)
1212
#define MACIS_USE_BOOST_SORT
1313
#include <boost/sort/pdqsort/pdqsort.hpp>
14+
#include <boost/sort/sort.hpp>
1415
#endif
1516

1617
namespace macis {
@@ -61,6 +62,31 @@ void reorder_ci_on_alpha(WfnIterator begin, WfnIterator end, double* C) {
6162
std::copy(reorder_C.begin(), reorder_C.end(), C);
6263
}
6364

65+
66+
/**
 *  @brief Merge runs of adjacent entries that share the same bitstring.
 *
 *  For every run of consecutive elements whose `state` compares equal, the
 *  `c_times_matel` values are summed into the first element of the run and
 *  the remaining elements of the run are dropped. Only *adjacent* duplicates
 *  are merged, so the range is expected to be grouped by `state` already
 *  (the sorting wrappers in this header sort before calling this routine).
 *
 *  @tparam PairIterator Iterator supporting `it + 1` whose value type exposes
 *                       `state` and `c_times_matel` members.
 *
 *  @param[in,out] pairs_begin Start of the range to accumulate.
 *  @param[in,out] pairs_end   End of the range to accumulate.
 *
 *  @returns Iterator to the new logical end of the range; elements in
 *           `[returned, pairs_end)` are left in a valid but unspecified
 *           state and are typically erased by the caller.
 */
template <typename PairIterator>
PairIterator accumulate_asci_pairs(PairIterator pairs_begin,
                                   PairIterator pairs_end) {
  // Nothing to merge in an empty range. This guard also keeps the
  // `pairs_begin + 1` below from stepping past `pairs_end`.
  if(pairs_begin == pairs_end) return pairs_end;

  // Accumulate the ASCI scores into first instance of unique bitstrings
  auto cur_it = pairs_begin;
  for(auto it = cur_it + 1; it != pairs_end; ++it) {
    if(it->state != cur_it->state) {
      // New bitstring encountered: it becomes the accumulation target
      cur_it = it;
    } else {
      // Same bitstring as the tracked entry: fold the contribution in
      cur_it->c_times_matel += it->c_times_matel;
      // Poison the consumed duplicate with NaN so that any accidental later
      // use of it is conspicuous rather than silently double counted.
      it->c_times_matel = NAN;
    }
  }

  // Remove duplicate bitstrings (keeps the first, i.e. accumulated, entry).
  // Compare through const references to avoid copying whole records.
  return std::unique(
      pairs_begin, pairs_end,
      [](const auto& x, const auto& y) { return x.state == y.state; });
}
89+
6490
template <typename PairIterator>
6591
PairIterator sort_and_accumulate_asci_pairs(PairIterator pairs_begin,
6692
PairIterator pairs_end) {
@@ -80,24 +106,29 @@ PairIterator sort_and_accumulate_asci_pairs(PairIterator pairs_begin,
80106
#endif
81107
(pairs_begin, pairs_end, comparator);
82108

83-
// Accumulate the ASCI scores into first instance of unique bitstrings
84-
auto cur_it = pairs_begin;
85-
for(auto it = cur_it + 1; it != pairs_end; ++it) {
86-
// If iterate is not the one being tracked, update the iterator
87-
if(it->state != cur_it->state) {
88-
cur_it = it;
89-
}
109+
return accumulate_asci_pairs(pairs_begin, pairs_end);
110+
}
90111

91-
// Accumulate
92-
else {
93-
cur_it->c_times_matel += it->c_times_matel;
94-
it->c_times_matel = NAN; // Zero out to expose potential bugs
95-
}
96-
}
112+
template <typename PairIterator>
113+
PairIterator stable_sort_and_accumulate_asci_pairs(PairIterator pairs_begin,
114+
PairIterator pairs_end) {
115+
const size_t npairs = std::distance(pairs_begin, pairs_end);
97116

98-
// Remote duplicate bitstrings
99-
return std::unique(pairs_begin, pairs_end,
100-
[](auto x, auto y) { return x.state == y.state; });
117+
if(!npairs) return pairs_end;
118+
119+
auto comparator = [](const auto& x, const auto& y) {
120+
return bitset_less(x.state, y.state);
121+
};
122+
123+
// Sort by bitstring
124+
#ifdef MACIS_USE_BOOST_SORT
125+
boost::sort::flat_stable_sort
126+
#else
127+
std::stable_sort
128+
#endif
129+
(pairs_begin, pairs_end, comparator);
130+
131+
return accumulate_asci_pairs(pairs_begin, pairs_end);
101132
}
102133

103134
template <typename WfnT>

include/macis/asci/mask_constraints.hpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
630630
template <typename WfnType, typename ContainerType>
631631
auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
632632
size_t nd_othr, const ContainerType& unique_alpha,
633-
int world_size) {
633+
int world_size, size_t nlevel_min = 0) {
634634
using wfn_traits = wavefunction_traits<WfnType>;
635635
using constraint_type = alpha_constraint<wfn_traits>;
636636
using string_type = typename constraint_type::constraint_type;
@@ -671,6 +671,21 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
671671
auto constraint = constraint_type::make_triplet(t_i, t_j, t_k);
672672
constraint_sizes.emplace_back(constraint, 0ul);
673673
}
674+
// Build up higher-order constraints as base if requested
675+
for(size_t ilevel = 0; ilevel < nlevel_min; ++ilevel) {
676+
std::vector cur_constraints = constraint_sizes;
677+
for(auto [c,nw] : cur_constraints) {
678+
const auto C_min = c.C_min();
679+
for(auto q_l = 0; q_l < C_min; ++q_l) {
680+
// Generate masks / counts
681+
string_type cn_C = c.C();
682+
cn_C.flip(q_l);
683+
string_type cn_B = c.B() >> (C_min - q_l);
684+
constraint_type c_next(cn_C, cn_B, q_l);
685+
constraint_sizes.emplace_back(c_next, 0ul);
686+
}
687+
}
688+
}
674689

675690
struct atomic_wrapper {
676691
std::atomic<size_t> value;
@@ -686,8 +701,10 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
686701
// Compute histogram
687702
const auto ntrip_full = constraint_sizes.size();
688703
std::vector<atomic_wrapper> constraint_work(ntrip_full, 0ul);
704+
int world_rank = comm_rank(MPI_COMM_WORLD);
689705
#pragma omp parallel for schedule(dynamic)
690-
for(auto i_trip = 0; i_trip < ntrip_full; ++i_trip) {
706+
for(auto i_trip = 0ul; i_trip < ntrip_full; ++i_trip) {
707+
if(!world_rank and !(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
691708
auto& [constraint, __nw] = constraint_sizes[i_trip];
692709
auto& c_nw = constraint_work[i_trip];
693710
size_t nw = 0;

include/macis/asci/pt2.hpp

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@
88

99
#pragma once
1010
#include <deque>
11+
#include <fstream>
12+
#include <sstream>
1113
#include <macis/asci/determinant_contributions.hpp>
1214
#include <macis/asci/determinant_sort.hpp>
1315

1416
namespace macis {
1517

1618
template <size_t N>
17-
double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
19+
double asci_pt2_constraint(ASCISettings asci_settings,
20+
wavefunction_iterator_t<N> cdets_begin,
1821
wavefunction_iterator_t<N> cdets_end,
1922
const double E_ASCI, const std::vector<double>& C,
2023
size_t norb, const double* T_pq, const double* G_red,
@@ -39,26 +42,34 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
3942
: spdlog::stdout_color_mt("asci_pt2");
4043

4144
const size_t ncdets = std::distance(cdets_begin, cdets_end);
45+
logger->info("[ASCI PT2 Settings]");
46+
logger->info(" * NDETS = {}", ncdets);
47+
logger->info(" * PT2_TOL = {}", asci_settings.pt2_tol);
48+
logger->info(" * PT2_RESERVE_COUNT = {}", asci_settings.pt2_reserve_count);
49+
logger->info(" * PT2_CONSTRAINT_LVL = {}", asci_settings.pt2_constraint_level);
50+
logger->info(" * PT2_PRUNE = {}", asci_settings.pt2_prune);
51+
logger->info(" * PT2_PRECMP_EPS = {}", asci_settings.pt2_precompute_eps);
52+
logger->info("");
4253

4354
// For each unique alpha, create a list of beta string and store metadata
4455
struct beta_coeff_data {
4556
spin_wfn_type beta_string;
46-
std::vector<uint32_t> occ_beta;
47-
std::vector<uint32_t> vir_beta;
57+
std::vector<uint8_t> occ_beta;
58+
std::vector<uint8_t> vir_beta;
4859
std::vector<double> orb_ens_alpha;
4960
std::vector<double> orb_ens_beta;
5061
double coeff;
5162
double h_diag;
5263

5364
size_t mem() const {
5465
return sizeof(spin_wfn_type) +
55-
(occ_beta.capacity() + vir_beta.capacity()) * sizeof(uint32_t) +
66+
(occ_beta.capacity() + vir_beta.capacity()) * sizeof(uint8_t) +
5667
(2 + orb_ens_alpha.capacity() + orb_ens_beta.capacity()) * sizeof(double);
5768
}
5869

5970
beta_coeff_data(double c, size_t norb,
6071
const std::vector<uint32_t>& occ_alpha, wfn_t<N> w,
61-
const HamiltonianGenerator<wfn_t<N>>& ham_gen) {
72+
const HamiltonianGenerator<wfn_t<N>>& ham_gen, bool pce) {
6273
coeff = c;
6374

6475
beta_string = wfn_traits::beta_string(w);
@@ -67,16 +78,24 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
6778
h_diag = ham_gen.matrix_element(w, w);
6879

6980
// Compute occ/vir for beta string
70-
spin_wfn_traits::state_to_occ_vir(norb, beta_string, occ_beta, vir_beta);
81+
std::vector<uint32_t> o_32, v_32;
82+
spin_wfn_traits::state_to_occ_vir(norb, beta_string, o_32, v_32);
83+
occ_beta.resize(o_32.size());
84+
std::copy(o_32.begin(), o_32.end(), occ_beta.begin());
85+
vir_beta.resize(v_32.size());
86+
std::copy(v_32.begin(), v_32.end(), vir_beta.begin());
7187

7288
// Precompute orbital energies
73-
//orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
74-
//orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
89+
if(pce) {
90+
orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, o_32);
91+
orb_ens_beta = ham_gen.single_orbital_ens(norb, o_32, occ_alpha);
92+
}
7593
}
7694
};
7795

7896
auto uniq_alpha = get_unique_alpha(cdets_begin, cdets_end);
7997
const size_t nuniq_alpha = uniq_alpha.size();
98+
logger->info(" * NUNIQ_ALPHA = {}", nuniq_alpha);
8099

81100
using unique_alpha_data = std::vector<beta_coeff_data>;
82101
std::vector<unique_alpha_data> uad(nuniq_alpha);
@@ -89,7 +108,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
89108
uad[i].reserve(nbeta);
90109
for(auto j = 0; j < nbeta; ++j, ++iw) {
91110
const auto& w = *(cdets_begin + iw);
92-
uad[i].emplace_back(C[iw], norb, occ_alpha, w, ham_gen);
111+
uad[i].emplace_back(C[iw], norb, occ_alpha, w, ham_gen,asci_settings.pt2_precompute_eps);
93112
}
94113
}
95114

@@ -105,7 +124,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
105124
}
106125
}
107126
printf("MEM REQ ALPH = %.2e\n", mem_alpha / gib);
108-
printf("MEM REQ CONT = %.2e\n", 70000000 * sizeof(asci_contrib<wfn_t<N>>)/ 1024./1024./1024);
127+
printf("MEM REQ CONT = %.2e\n", asci_settings.pt2_reserve_count * sizeof(asci_contrib<wfn_t<N>>)/ gib);
109128
}
110129
MPI_Barrier(comm);
111130

@@ -125,11 +144,29 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
125144
// auto constraints = dist_constraint_general<wfn_t<N>>(
126145
// 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
127146
auto constraints = gen_constraints_general<wfn_t<N>>(
128-
10, norb, n_sing_beta, n_doub_beta, uniq_alpha,
129-
world_size * omp_get_max_threads());
147+
asci_settings.pt2_constraint_level, norb, n_sing_beta,
148+
n_doub_beta, uniq_alpha, world_size * omp_get_max_threads(), 0);
130149
auto gen_c_en = clock_type::now();
131150
duration_type gen_c_dur = gen_c_en - gen_c_st;
132151
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());
152+
//if(!world_rank) {
153+
// std::ofstream c_file("constraint_work.txt");
154+
// std::stringstream ss;
155+
// for(auto [c,s] : constraints) {
156+
// ss << c.C() << " " << s << std::endl;
157+
// }
158+
// auto str = ss.str();
159+
// c_file.write(str.c_str(), str.size());
160+
//}
161+
//if(!world_rank) {
162+
// std::ofstream c_file("unique_alpha.txt");
163+
// std::stringstream ss;
164+
// for(size_t i = 0; i < nuniq_alpha; ++i) {
165+
// ss << uniq_alpha[i].first << " " << uniq_alpha[i].second << std::endl;
166+
// }
167+
// auto str = ss.str();
168+
// c_file.write(str.c_str(), str.size());
169+
//}
133170

134171
double EPT2 = 0.0;
135172
size_t NPT2 = 0;
@@ -138,14 +175,14 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
138175

139176
// Global atomic task-id counter
140177
global_atomic<size_t> nxtval(comm);
141-
const double h_el_tol = 1e-6;
178+
const double h_el_tol = asci_settings.pt2_tol;
142179

143180
auto pt2_st = clock_type::now();
144181
#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
145182
{
146183
// Process ASCI pair contributions for each constraint
147184
asci_contrib_container<wfn_t<N>> asci_pairs;
148-
asci_pairs.reserve(70000000ul);
185+
//asci_pairs.reserve(asci_settings.pt2_reserve_count);
149186
size_t ic = 0;
150187
while(ic < ncon_total) {
151188
// Atomically get the next task ID and increment for other
@@ -173,12 +210,22 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
173210
const auto c = C[iw];
174211
const auto& beta_det = bcd[j_beta].beta_string;
175212
const auto h_diag = bcd[j_beta].h_diag;
176-
const auto& occ_beta = bcd[j_beta].occ_beta;
177-
const auto& vir_beta = bcd[j_beta].vir_beta;
178-
//const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
179-
//const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
180-
auto orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
181-
auto orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
213+
214+
// TODO: These copies are slow
215+
const auto& occ_beta_8 = bcd[j_beta].occ_beta;
216+
const auto& vir_beta_8 = bcd[j_beta].vir_beta;
217+
std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
218+
std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
219+
std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
220+
221+
std::vector<double> orb_ens_alpha, orb_ens_beta;
222+
if(asci_settings.pt2_precompute_eps) {
223+
orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
224+
orb_ens_beta = bcd[j_beta].orb_ens_beta;
225+
} else {
226+
orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
227+
orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
228+
}
182229

183230
// AA excitations
184231
generate_constraint_singles_contributions_ss(
@@ -215,10 +262,12 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
215262
{w, std::numeric_limits<double>::infinity(), 1.0});
216263
}
217264
}
218-
if(asci_pairs.size() > 70000000 and asci_pairs.size() != old_pair_size) {
265+
if(asci_settings.pt2_prune and asci_pairs.size() > asci_settings.pt2_reserve_count and asci_pairs.size() != old_pair_size) {
219266
// Cleanup
220-
auto uit = sort_and_accumulate_asci_pairs(asci_pairs.begin(),
267+
auto uit = stable_sort_and_accumulate_asci_pairs(asci_pairs.begin(),
221268
asci_pairs.end());
269+
asci_pairs.erase(uit, asci_pairs.end());
270+
uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
222271
asci_pairs.erase(uit, asci_pairs.end());
223272
printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
224273
omp_get_thread_num(), ic, ncon_total, i_alpha,
@@ -235,8 +284,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
235284
asci_pairs.end());
236285
for(auto it = asci_pairs.begin(); it != uit; ++it) {
237286
if(!std::isinf(it->c_times_matel)) {
238-
EPT2_local +=
239-
(it->c_times_matel * it->c_times_matel) / it->h_diag;
287+
EPT2_local += it->pt2();
240288
NPT2_local++;
241289
}
242290
}

0 commit comments

Comments
 (0)