Skip to content

Commit 17e9265

Browse files
Add safeguards to PT2: avoid precomputing orbital energies to save memory
1 parent 139e789 commit 17e9265

1 file changed

Lines changed: 40 additions & 21 deletions

File tree

include/macis/asci/pt2.hpp

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
5050
double coeff;
5151
double h_diag;
5252

53+
// Returns an estimate, in bytes, of the memory held by this entry: one
// spin_wfn_type, the allocated capacity (capacity(), not size()) of the
// occ_beta and vir_beta index containers counted as uint32_t elements, and
// the allocated capacity of orb_ens_alpha / orb_ens_beta plus two extra
// doubles.
// NOTE(review): the constant 2 presumably accounts for the coeff and h_diag
// members -- confirm against the full struct definition.
// NOTE(review): the constructor's assignments to orb_ens_alpha /
// orb_ens_beta are commented out in this revision, so their capacity is
// presumably zero unless they are filled elsewhere -- verify.
size_t mem() const {
54+
return sizeof(spin_wfn_type) +
55+
(occ_beta.capacity() + vir_beta.capacity()) * sizeof(uint32_t) +
56+
(2 + orb_ens_alpha.capacity() + orb_ens_beta.capacity()) * sizeof(double);
57+
}
58+
5359
beta_coeff_data(double c, size_t norb,
5460
const std::vector<uint32_t>& occ_alpha, wfn_t<N> w,
5561
const HamiltonianGenerator<wfn_t<N>>& ham_gen) {
@@ -64,17 +70,13 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
6470
spin_wfn_traits::state_to_occ_vir(norb, beta_string, occ_beta, vir_beta);
6571

6672
// Precompute orbital energies
67-
orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
68-
orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
73+
//orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
74+
//orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
6975
}
7076
};
7177

7278
auto uniq_alpha = get_unique_alpha(cdets_begin, cdets_end);
7379
const size_t nuniq_alpha = uniq_alpha.size();
74-
std::vector<wfn_t<N>> uniq_alpha_wfn(nuniq_alpha);
75-
std::transform(
76-
uniq_alpha.begin(), uniq_alpha.end(), uniq_alpha_wfn.begin(),
77-
[](const auto& p) { return wfn_traits::from_spin(p.first, 0); });
7880

7981
using unique_alpha_data = std::vector<beta_coeff_data>;
8082
std::vector<unique_alpha_data> uad(nuniq_alpha);
@@ -91,6 +93,22 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
9193
}
9294
}
9395

96+
if(world_rank == 0) {
97+
constexpr double gib = 1024 * 1024 * 1024;
98+
printf("MEM REQ DETS = %.2e\n", ncdets * sizeof(wfn_t<N>) / gib);
99+
printf("MEM REQ C = %.2e\n", ncdets * sizeof(double) / gib);
100+
size_t mem_alpha = 0;
101+
for( auto i = 0ul; i < nuniq_alpha; ++i) {
102+
mem_alpha += sizeof(spin_wfn_type);
103+
for(auto j = 0ul; j < uad[i].size(); ++j) {
104+
mem_alpha += uad[i][j].mem();
105+
}
106+
}
107+
printf("MEM REQ ALPH = %.2e\n", mem_alpha / gib);
108+
printf("MEM REQ CONT = %.2e\n", 70000000 * sizeof(asci_contrib<wfn_t<N>>)/ 1024./1024./1024);
109+
}
110+
MPI_Barrier(comm);
111+
94112
const auto n_occ_alpha = spin_wfn_traits::count(uniq_alpha[0].first);
95113
const auto n_vir_alpha = norb - n_occ_alpha;
96114
const auto n_sing_alpha = n_occ_alpha * n_vir_alpha;
@@ -120,14 +138,14 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
120138

121139
// Global atomic task-id counter
122140
global_atomic<size_t> nxtval(comm);
123-
const double h_el_tol = 1e-16;
141+
const double h_el_tol = 1e-6;
124142

125143
auto pt2_st = clock_type::now();
126144
#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
127145
{
128146
// Process ASCI pair contributions for each constraint
129147
asci_contrib_container<wfn_t<N>> asci_pairs;
130-
asci_pairs.reserve(100000000ul);
148+
asci_pairs.reserve(70000000ul);
131149
size_t ic = 0;
132150
while(ic < ncon_total) {
133151
// Atomically get the next task ID and increment for other
@@ -143,6 +161,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
143161
omp_get_thread_num(), ic, ncon_total);
144162

145163
for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
164+
const size_t old_pair_size = asci_pairs.size();
146165
const auto& alpha_det = uniq_alpha[i_alpha].first;
147166
const auto occ_alpha = bits_to_indices(alpha_det);
148167
const bool alpha_satisfies_con = satisfies_constraint(alpha_det, con);
@@ -156,8 +175,10 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
156175
const auto h_diag = bcd[j_beta].h_diag;
157176
const auto& occ_beta = bcd[j_beta].occ_beta;
158177
const auto& vir_beta = bcd[j_beta].vir_beta;
159-
const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
160-
const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
178+
//const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
179+
//const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
180+
auto orb_ens_alpha = ham_gen.single_orbital_ens(norb, occ_alpha, occ_beta);
181+
auto orb_ens_beta = ham_gen.single_orbital_ens(norb, occ_beta, occ_alpha);
161182

162183
// AA excitations
163184
generate_constraint_singles_contributions_ss(
@@ -194,17 +215,15 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
194215
{w, std::numeric_limits<double>::infinity(), 1.0});
195216
}
196217
}
197-
198-
// if(not (i_alpha%10)) {
199-
//// Cleanup
200-
// auto uit = sort_and_accumulate_asci_pairs(asci_pairs.begin(),
201-
// asci_pairs.end());
202-
// asci_pairs.erase(uit, asci_pairs.end());
203-
// printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ =
204-
// %lu\n", world_rank,
205-
// omp_get_thread_num(), ic, ncon_total, i_alpha,
206-
// nuniq_alpha, asci_pairs.size());
207-
// }
218+
if(asci_pairs.size() > 70000000 and asci_pairs.size() != old_pair_size) {
219+
// Cleanup
220+
auto uit = sort_and_accumulate_asci_pairs(asci_pairs.begin(),
221+
asci_pairs.end());
222+
asci_pairs.erase(uit, asci_pairs.end());
223+
printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
224+
omp_get_thread_num(), ic, ncon_total, i_alpha,
225+
nuniq_alpha, asci_pairs.size());
226+
}
208227

209228
} // Unique Alpha Loop
210229

0 commit comments

Comments
 (0)