Skip to content

Commit f99ad4d

Browse files
committed
debug 4749
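Replaces usage of the lmp_list send/recv arrays with newly built vectors: send-list indices are remapped through fwd_map (entries without a valid mapping are dropped), and the resulting per-swap send counts are exchanged via MPI to obtain the matching receive counts. Updates tensor construction to use these new vectors, improving correctness and flexibility in distributed communication.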
1 parent accc331 commit f99ad4d

1 file changed

Lines changed: 80 additions & 10 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
#include "device.h"
1111
#include "errors.h"
1212

13+
#ifdef USE_MPI
14+
#include <mpi.h>
15+
#ifdef OMPI_MPI_H
16+
#include <mpi-ext.h>
17+
#endif
18+
#endif
19+
1320
using namespace deepmd;
1421

1522
void DeepPotPT::translate_error(std::function<void()> f) {
@@ -174,16 +181,79 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
174181
nlist_data.padding();
175182
if (do_message_passing) {
176183
int nswap = lmp_list.nswap;
184+
185+
std::vector<int> sendnum_new(nswap, 0);
186+
std::vector<int> sendlist_new;
187+
sendlist_new.reserve(
188+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0)
189+
);
190+
for (int s = 0; s < nswap; ++s) {
191+
int cnt = 0;
192+
for (int k = 0; k < lmp_list.sendnum[s]; ++k) {
193+
const int old_idx = lmp_list.sendlist[s][k];
194+
int mapped = (old_idx >= 0 && old_idx < (int)fwd_map.size())
195+
? fwd_map[old_idx]
196+
: -1;
197+
if (mapped >= 0) {
198+
sendlist_new.push_back(mapped);
199+
++cnt;
200+
}
201+
}
202+
sendnum_new[s] = cnt;
203+
}
204+
205+
std::vector<int> recvnum_new(nswap, 0);
206+
#ifdef MPI_FOUND
207+
if (lmp_list.world) {
208+
MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
209+
const int TAG_BASE = 0x7a31;
210+
for (int s = 0; s < nswap; ++s) {
211+
const int send_to = lmp_list.sendproc[s];
212+
const int recv_from = lmp_list.recvproc[s];
213+
int send_cnt = sendnum_new[s];
214+
int recv_cnt = 0;
215+
MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s,
216+
&recv_cnt, 1, MPI_INT, recv_from, TAG_BASE + s,
217+
comm, MPI_STATUS_IGNORE);
218+
recvnum_new[s] = recv_cnt;
219+
}
220+
} else
221+
#endif
222+
{
223+
for (int s = 0; s < nswap; ++s) recvnum_new[s] = sendnum_new[s];
224+
}
225+
226+
std::vector<int> firstrecv_new(nswap, 0);
227+
int acc = 0;
228+
for (int s = 0; s < nswap; ++s) {
229+
firstrecv_new[s] = acc;
230+
acc += recvnum_new[s];
231+
}
232+
177233
torch::Tensor sendproc_tensor =
178234
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
179235
torch::Tensor recvproc_tensor =
180236
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
181-
torch::Tensor firstrecv_tensor =
182-
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
183-
torch::Tensor recvnum_tensor =
184-
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
185-
torch::Tensor sendnum_tensor =
186-
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
237+
238+
torch::Tensor firstrecv_tensor =
239+
torch::from_blob(firstrecv_new.data(), {nswap}, int32_option).clone();
240+
torch::Tensor recvnum_tensor =
241+
torch::from_blob(recvnum_new.data(), {nswap}, int32_option).clone();
242+
torch::Tensor sendnum_tensor =
243+
torch::from_blob(sendnum_new.data(), {nswap}, int32_option).clone();
244+
245+
torch::Tensor sendlist_tensor =
246+
torch::from_blob(sendlist_new.data(),
247+
{ static_cast<long>(sendlist_new.size()) },
248+
int32_option).clone();
249+
250+
251+
// torch::Tensor firstrecv_tensor =
252+
// torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
253+
// torch::Tensor recvnum_tensor =
254+
// torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
255+
// torch::Tensor sendnum_tensor =
256+
// torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
187257
torch::Tensor communicator_tensor;
188258
if (lmp_list.world == 0) {
189259
communicator_tensor = torch::empty({1}, torch::kInt64);
@@ -193,10 +263,10 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
193263
}
194264

195265
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
196-
int total_send =
197-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
198-
torch::Tensor sendlist_tensor =
199-
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
266+
// int total_send =
267+
// std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
268+
// torch::Tensor sendlist_tensor =
269+
// torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
200270
comm_dict.insert_or_assign("send_list", sendlist_tensor);
201271
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
202272
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);

0 commit comments

Comments
 (0)