@@ -174,63 +174,37 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
174174 nlist_data.padding ();
175175 if (do_message_passing) {
176176 int nswap = lmp_list.nswap ;
177-
178- std::vector<int > sendnum_new (nswap, 0 );
179- std::vector<int > sendlist_new;
180- sendlist_new.reserve (
181- std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap, 0 ));
177+ std::vector<std::vector<int >> sendlist_new;
178+ sendlist_new.resize (nswap);
179+ // select real atoms in sendlist
182180 for (int s = 0 ; s < nswap; ++s) {
183181 int cnt = 0 ;
182+ sendlist_new[s].reserve (lmp_list.sendnum [s]);
184183 for (int k = 0 ; k < lmp_list.sendnum [s]; ++k) {
185184 const int old_idx = lmp_list.sendlist [s][k];
186185 int mapped = (old_idx >= 0 && old_idx < (int )fwd_map.size ())
187186 ? fwd_map[old_idx]
188187 : -1 ;
189188 if (mapped >= 0 ) {
190- sendlist_new.push_back (mapped);
189+ sendlist_new[s] .push_back (mapped);
191190 ++cnt;
192191 }
193192 }
194- sendnum_new[s] = cnt;
195- }
196-
197- std::vector<int > recvnum_new (nswap, 0 );
198- // need check
199- for (int s = 0 ; s < nswap; ++s) {
200- recvnum_new[s] = sendnum_new[s];
201- }
202-
203- std::vector<int > firstrecv_new (nswap, 0 );
204- int acc = 0 ;
205- for (int s = 0 ; s < nswap; ++s) {
206- firstrecv_new[s] = acc;
207- acc += recvnum_new[s];
193+ std::memcpy (lmp_list.sendlist [s], sendlist_new[s].data (), cnt * sizeof (int ));
194+ lmp_list.sendnum [s] = cnt;
195+ lmp_list.recvnum [s] = cnt;
208196 }
209-
210197 torch::Tensor sendproc_tensor =
211198 torch::from_blob (lmp_list.sendproc , {nswap}, int32_option);
212199 torch::Tensor recvproc_tensor =
213200 torch::from_blob (lmp_list.recvproc , {nswap}, int32_option);
214201
215202 torch::Tensor firstrecv_tensor =
216- torch::from_blob (firstrecv_new. data () , {nswap}, int32_option). clone ( );
203+ torch::from_blob (lmp_list. firstrecv , {nswap}, int32_option);
217204 torch::Tensor recvnum_tensor =
218- torch::from_blob (recvnum_new. data () , {nswap}, int32_option). clone ( );
205+ torch::from_blob (lmp_list. recvnum , {nswap}, int32_option);
219206 torch::Tensor sendnum_tensor =
220- torch::from_blob (sendnum_new.data (), {nswap}, int32_option).clone ();
221-
222- torch::Tensor sendlist_tensor =
223- torch::from_blob (sendlist_new.data (),
224- {static_cast <long >(sendlist_new.size ())},
225- int32_option)
226- .clone ();
227-
228- // torch::Tensor firstrecv_tensor =
229- // torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
230- // torch::Tensor recvnum_tensor =
231- // torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
232- // torch::Tensor sendnum_tensor =
233- // torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
207+ torch::from_blob (lmp_list.sendnum , {nswap}, int32_option);
234208 torch::Tensor communicator_tensor;
235209 if (lmp_list.world == 0 ) {
236210 communicator_tensor = torch::empty ({1 }, torch::kInt64 );
@@ -240,12 +214,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
240214 }
241215
242216 torch::Tensor nswap_tensor = torch::tensor (nswap, int32_option);
243- // int total_send =
244- // std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap,
245- // 0);
246- // torch::Tensor sendlist_tensor =
247- // torch::from_blob(lmp_list.sendlist, {total_send},
248- // int32_option);
217+
218+ int total_send =
219+ std::accumulate (lmp_list.sendnum , lmp_list.sendnum + nswap,
220+ 0 );
221+ torch::Tensor sendlist_tensor =
222+ torch::from_blob (lmp_list.sendlist , {total_send},
223+ int32_option);
249224 comm_dict.insert_or_assign (" send_list" , sendlist_tensor);
250225 comm_dict.insert_or_assign (" send_proc" , sendproc_tensor);
251226 comm_dict.insert_or_assign (" recv_proc" , recvproc_tensor);
0 commit comments