1010#include " device.h"
1111#include " errors.h"
1212
13+ #ifdef USE_MPI
14+ #include < mpi.h>
15+ #ifdef OMPI_MPI_H
16+ #include < mpi-ext.h>
17+ #endif
18+ #endif
19+
1320using namespace deepmd ;
1421
1522void DeepPotPT::translate_error (std::function<void ()> f) {
@@ -174,16 +181,79 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
174181 nlist_data.padding ();
175182 if (do_message_passing) {
176183 int nswap = lmp_list.nswap ;
184+
185+ std::vector<int > sendnum_new (nswap, 0 );
186+ std::vector<int > sendlist_new;
187+ sendlist_new.reserve (
188+ std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap, 0 )
189+ );
190+ for (int s = 0 ; s < nswap; ++s) {
191+ int cnt = 0 ;
192+ for (int k = 0 ; k < lmp_list.sendnum [s]; ++k) {
193+ const int old_idx = lmp_list.sendlist [s][k];
194+ int mapped = (old_idx >= 0 && old_idx < (int )fwd_map.size ())
195+ ? fwd_map[old_idx]
196+ : -1 ;
197+ if (mapped >= 0 ) {
198+ sendlist_new.push_back (mapped);
199+ ++cnt;
200+ }
201+ }
202+ sendnum_new[s] = cnt;
203+ }
204+
205+ std::vector<int > recvnum_new (nswap, 0 );
206+ #ifdef MPI_FOUND
207+ if (lmp_list.world ) {
208+ MPI_Comm comm = *static_cast <MPI_Comm*>(lmp_list.world );
209+ const int TAG_BASE = 0x7a31 ;
210+ for (int s = 0 ; s < nswap; ++s) {
211+ const int send_to = lmp_list.sendproc [s];
212+ const int recv_from = lmp_list.recvproc [s];
213+ int send_cnt = sendnum_new[s];
214+ int recv_cnt = 0 ;
215+ MPI_Sendrecv (&send_cnt, 1 , MPI_INT, send_to, TAG_BASE + s,
216+ &recv_cnt, 1 , MPI_INT, recv_from, TAG_BASE + s,
217+ comm, MPI_STATUS_IGNORE);
218+ recvnum_new[s] = recv_cnt;
219+ }
220+ } else
221+ #endif
222+ {
223+ for (int s = 0 ; s < nswap; ++s) recvnum_new[s] = sendnum_new[s];
224+ }
225+
226+ std::vector<int > firstrecv_new (nswap, 0 );
227+ int acc = 0 ;
228+ for (int s = 0 ; s < nswap; ++s) {
229+ firstrecv_new[s] = acc;
230+ acc += recvnum_new[s];
231+ }
232+
177233 torch::Tensor sendproc_tensor =
178234 torch::from_blob (lmp_list.sendproc , {nswap}, int32_option);
179235 torch::Tensor recvproc_tensor =
180236 torch::from_blob (lmp_list.recvproc , {nswap}, int32_option);
181- torch::Tensor firstrecv_tensor =
182- torch::from_blob (lmp_list.firstrecv , {nswap}, int32_option);
183- torch::Tensor recvnum_tensor =
184- torch::from_blob (lmp_list.recvnum , {nswap}, int32_option);
185- torch::Tensor sendnum_tensor =
186- torch::from_blob (lmp_list.sendnum , {nswap}, int32_option);
237+
238+ torch::Tensor firstrecv_tensor =
239+ torch::from_blob (firstrecv_new.data (), {nswap}, int32_option).clone ();
240+ torch::Tensor recvnum_tensor =
241+ torch::from_blob (recvnum_new.data (), {nswap}, int32_option).clone ();
242+ torch::Tensor sendnum_tensor =
243+ torch::from_blob (sendnum_new.data (), {nswap}, int32_option).clone ();
244+
245+ torch::Tensor sendlist_tensor =
246+ torch::from_blob (sendlist_new.data (),
247+ { static_cast <long >(sendlist_new.size ()) },
248+ int32_option).clone ();
249+
250+
251+ // torch::Tensor firstrecv_tensor =
252+ // torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
253+ // torch::Tensor recvnum_tensor =
254+ // torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
255+ // torch::Tensor sendnum_tensor =
256+ // torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
187257 torch::Tensor communicator_tensor;
188258 if (lmp_list.world == 0 ) {
189259 communicator_tensor = torch::empty ({1 }, torch::kInt64 );
@@ -193,10 +263,10 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
193263 }
194264
195265 torch::Tensor nswap_tensor = torch::tensor (nswap, int32_option);
196- int total_send =
197- std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap, 0 );
198- torch::Tensor sendlist_tensor =
199- torch::from_blob (lmp_list.sendlist , {total_send}, int32_option);
266+ // int total_send =
267+ // std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
268+ // torch::Tensor sendlist_tensor =
269+ // torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
200270 comm_dict.insert_or_assign (" send_list" , sendlist_tensor);
201271 comm_dict.insert_or_assign (" send_proc" , sendproc_tensor);
202272 comm_dict.insert_or_assign (" recv_proc" , recvproc_tensor);
0 commit comments