@@ -182,18 +182,17 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
182182 if (do_message_passing) {
183183 int nswap = lmp_list.nswap ;
184184
185- std::vector<int > sendnum_new (nswap, 0 );
186- std::vector<int > sendlist_new;
185+ std::vector<int > sendnum_new (nswap, 0 );
186+ std::vector<int > sendlist_new;
187187 sendlist_new.reserve (
188- std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap, 0 )
189- );
188+ std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap, 0 ));
190189 for (int s = 0 ; s < nswap; ++s) {
191190 int cnt = 0 ;
192191 for (int k = 0 ; k < lmp_list.sendnum [s]; ++k) {
193192 const int old_idx = lmp_list.sendlist [s][k];
194193 int mapped = (old_idx >= 0 && old_idx < (int )fwd_map.size ())
195- ? fwd_map[old_idx]
196- : -1 ;
194+ ? fwd_map[old_idx]
195+ : -1 ;
197196 if (mapped >= 0 ) {
198197 sendlist_new.push_back (mapped);
199198 ++cnt;
@@ -202,58 +201,60 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
202201 sendnum_new[s] = cnt;
203202 }
204203
205- std::vector<int > recvnum_new (nswap, 0 );
206- #ifdef MPI_FOUND
207- if (lmp_list.world ) {
208- MPI_Comm comm = *static_cast <MPI_Comm*>(lmp_list.world );
209- const int TAG_BASE = 0x7a31 ;
210- for (int s = 0 ; s < nswap; ++s) {
211- const int send_to = lmp_list.sendproc [s];
212- const int recv_from = lmp_list.recvproc [s];
213- int send_cnt = sendnum_new[s];
214- int recv_cnt = 0 ;
215- MPI_Sendrecv (&send_cnt, 1 , MPI_INT, send_to, TAG_BASE + s,
216- &recv_cnt, 1 , MPI_INT, recv_from, TAG_BASE + s,
217- comm, MPI_STATUS_IGNORE);
218- recvnum_new[s] = recv_cnt;
219- }
220- } else
221- #endif
222- {
223- for (int s = 0 ; s < nswap; ++s) recvnum_new[s] = sendnum_new[s];
224- }
204+ std::vector<int > recvnum_new (nswap, 0 );
205+ #ifdef MPI_FOUND
206+ if (lmp_list.world ) {
207+ MPI_Comm comm = *static_cast <MPI_Comm*>(lmp_list.world );
208+ const int TAG_BASE = 0x7a31 ;
209+ for (int s = 0 ; s < nswap; ++s) {
210+ const int send_to = lmp_list.sendproc [s];
211+ const int recv_from = lmp_list.recvproc [s];
212+ int send_cnt = sendnum_new[s];
213+ int recv_cnt = 0 ;
214+ MPI_Sendrecv (&send_cnt, 1 , MPI_INT, send_to, TAG_BASE + s, &recv_cnt,
215+ 1 , MPI_INT, recv_from, TAG_BASE + s, comm,
216+ MPI_STATUS_IGNORE);
217+ recvnum_new[s] = recv_cnt;
218+ }
219+ } else
220+ #endif
221+ {
222+ for (int s = 0 ; s < nswap; ++s) {
223+ recvnum_new[s] = sendnum_new[s];
224+ }
225+ }
225226
226- std::vector<int > firstrecv_new (nswap, 0 );
227- int acc = 0 ;
228- for (int s = 0 ; s < nswap; ++s) {
229- firstrecv_new[s] = acc;
230- acc += recvnum_new[s];
231- }
227+ std::vector<int > firstrecv_new (nswap, 0 );
228+ int acc = 0 ;
229+ for (int s = 0 ; s < nswap; ++s) {
230+ firstrecv_new[s] = acc;
231+ acc += recvnum_new[s];
232+ }
232233
233234 torch::Tensor sendproc_tensor =
234235 torch::from_blob (lmp_list.sendproc , {nswap}, int32_option);
235236 torch::Tensor recvproc_tensor =
236237 torch::from_blob (lmp_list.recvproc , {nswap}, int32_option);
237238
238- torch::Tensor firstrecv_tensor =
239- torch::from_blob (firstrecv_new.data (), {nswap}, int32_option).clone ();
240- torch::Tensor recvnum_tensor =
241- torch::from_blob (recvnum_new.data (), {nswap}, int32_option).clone ();
242- torch::Tensor sendnum_tensor =
243- torch::from_blob (sendnum_new.data (), {nswap}, int32_option).clone ();
244-
245- torch::Tensor sendlist_tensor =
246- torch::from_blob (sendlist_new.data (),
247- { static_cast <long >(sendlist_new.size ()) },
248- int32_option).clone ();
239+ torch::Tensor firstrecv_tensor =
240+ torch::from_blob (firstrecv_new.data (), {nswap}, int32_option).clone ();
241+ torch::Tensor recvnum_tensor =
242+ torch::from_blob (recvnum_new.data (), {nswap}, int32_option).clone ();
243+ torch::Tensor sendnum_tensor =
244+ torch::from_blob (sendnum_new.data (), {nswap}, int32_option).clone ();
249245
246+ torch::Tensor sendlist_tensor =
247+ torch::from_blob (sendlist_new.data (),
248+ {static_cast <long >(sendlist_new.size ())},
249+ int32_option)
250+ .clone ();
250251
251- // torch::Tensor firstrecv_tensor =
252- // torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
253- // torch::Tensor recvnum_tensor =
254- // torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
255- // torch::Tensor sendnum_tensor =
256- // torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
252+ // torch::Tensor firstrecv_tensor =
253+ // torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
254+ // torch::Tensor recvnum_tensor =
255+ // torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
256+ // torch::Tensor sendnum_tensor =
257+ // torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
257258 torch::Tensor communicator_tensor;
258259 if (lmp_list.world == 0 ) {
259260 communicator_tensor = torch::empty ({1 }, torch::kInt64 );
@@ -263,10 +264,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
263264 }
264265
265266 torch::Tensor nswap_tensor = torch::tensor (nswap, int32_option);
266- // int total_send =
267- // std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
268- // torch::Tensor sendlist_tensor =
269- // torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
267+ // int total_send =
268+ // std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap,
269+ // 0);
270+ // torch::Tensor sendlist_tensor =
271+ // torch::from_blob(lmp_list.sendlist, {total_send},
272+ // int32_option);
270273 comm_dict.insert_or_assign (" send_list" , sendlist_tensor);
271274 comm_dict.insert_or_assign (" send_proc" , sendproc_tensor);
272275 comm_dict.insert_or_assign (" recv_proc" , recvproc_tensor);
0 commit comments