Skip to content

Commit 4274499

Browse files
Refactored NXTVAL functionality into a proper C++ type (global_atomic)
1 parent 2831b59 commit 4274499

2 files changed

Lines changed: 49 additions & 6 deletions

File tree

include/macis/asci/pt2.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
133133
// std::mutex print_barrier;
134134

135135
const size_t ncon_total = constraints.size();
136-
duration_type lock_wait_dur(0.0);
136+
#if 0
137137
MPI_Win window;
138138
// MPI_Win_create( &window_count, sizeof(size_t), sizeof(size_t),
139139
// MPI_INFO_NULL, comm, &window );
@@ -142,6 +142,9 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
142142
&window_buffer, &window);
143143
if(window == MPI_WIN_NULL) throw std::runtime_error("Window failed");
144144
MPI_Win_lock_all(MPI_MODE_NOCHECK, window);
145+
#else
146+
global_atomic<size_t> nxtval(comm);
147+
#endif
145148
// Process ASCI pair contributions for each constraint
146149
#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
147150
{
@@ -152,8 +155,9 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
152155
size_t ic = 0;
153156
while(ic < ncon_total) {
154157
size_t ntake = ic < 1000 ? 1 : 10;
155-
MPI_Fetch_and_op(&ntake, &ic, MPI_UINT64_T, 0, 0, MPI_SUM, window);
156-
MPI_Win_flush(0, window);
158+
//MPI_Fetch_and_op(&ntake, &ic, MPI_UINT64_T, 0, 0, MPI_SUM, window);
159+
//MPI_Win_flush(0, window);
160+
ic = nxtval.fetch_and_add(ntake);
157161

158162
// Loop over assigned tasks
159163
const size_t c_end = std::min(ncon_total, ic + ntake);
@@ -259,12 +263,11 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
259263
} else {
260264
logger->info("* PT2_DUR = ${:.2e} ms", local_pt2_dur);
261265
}
262-
printf("[rank %d] WAIT_DUR = %.2e\n", world_rank, lock_wait_dur.count());
263266

264267
NPT2 = allreduce(NPT2, MPI_SUM, comm);
265268
logger->info("* NPT2 = {}", NPT2);
266-
MPI_Win_unlock_all(window);
267-
MPI_Win_free(&window);
269+
//MPI_Win_unlock_all(window);
270+
//MPI_Win_free(&window);
268271

269272
return EPT2;
270273
}

include/macis/util/mpi.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,5 +228,45 @@ struct mpi_traits<std::bitset<N>> {
228228
}
229229
};
230230

231+
232+
233+
234+
template <typename T>
235+
class global_atomic {
236+
MPI_Win window_;
237+
T* buffer_;
238+
239+
public:
240+
241+
global_atomic() = delete;
242+
243+
global_atomic(MPI_Comm comm) {
244+
MPI_Win_allocate(sizeof(T), sizeof(T), MPI_INFO_NULL, comm, &buffer_,
245+
&window_);
246+
if(window_ == MPI_WIN_NULL) {
247+
throw std::runtime_error("Window creation failed");
248+
}
249+
*buffer_ = 0;
250+
MPI_Win_lock_all(MPI_MODE_NOCHECK, window_);
251+
}
252+
253+
~global_atomic() noexcept {
254+
MPI_Win_unlock_all(window_);
255+
MPI_Win_free(&window_);
256+
}
257+
258+
global_atomic(const global_atomic&) = default;
259+
global_atomic(global_atomic&&) noexcept = default;
260+
261+
T fetch_and_add(T val) {
262+
T next_val;
263+
MPI_Fetch_and_op(&val, &next_val, mpi_traits<T>::datatype(), 0, 0, MPI_SUM,
264+
window_);
265+
MPI_Win_flush(0,window_);
266+
return next_val;
267+
}
268+
};
269+
270+
231271
} // namespace macis
232272
#endif

0 commit comments

Comments
 (0)