Skip to content

Commit a4f565f

Browse files
committed
Update DeepPotPT.cc
1 parent 4a2027e commit a4f565f

1 file changed

Lines changed: 18 additions & 43 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -174,63 +174,37 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
174174
nlist_data.padding();
175175
if (do_message_passing) {
176176
int nswap = lmp_list.nswap;
177-
178-
std::vector<int> sendnum_new(nswap, 0);
179-
std::vector<int> sendlist_new;
180-
sendlist_new.reserve(
181-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0));
177+
std::vector<std::vector<int>> sendlist_new;
178+
sendlist_new.resize(nswap);
179+
// select real atoms in sendlist
182180
for (int s = 0; s < nswap; ++s) {
183181
int cnt = 0;
182+
sendlist_new[s].reserve(lmp_list.sendnum[s]);
184183
for (int k = 0; k < lmp_list.sendnum[s]; ++k) {
185184
const int old_idx = lmp_list.sendlist[s][k];
186185
int mapped = (old_idx >= 0 && old_idx < (int)fwd_map.size())
187186
? fwd_map[old_idx]
188187
: -1;
189188
if (mapped >= 0) {
190-
sendlist_new.push_back(mapped);
189+
sendlist_new[s].push_back(mapped);
191190
++cnt;
192191
}
193192
}
194-
sendnum_new[s] = cnt;
195-
}
196-
197-
std::vector<int> recvnum_new(nswap, 0);
198-
// need check
199-
for (int s = 0; s < nswap; ++s) {
200-
recvnum_new[s] = sendnum_new[s];
201-
}
202-
203-
std::vector<int> firstrecv_new(nswap, 0);
204-
int acc = 0;
205-
for (int s = 0; s < nswap; ++s) {
206-
firstrecv_new[s] = acc;
207-
acc += recvnum_new[s];
193+
std::memcpy(lmp_list.sendlist[s], sendlist_new[s].data(), cnt * sizeof(int));
194+
lmp_list.sendnum[s] = cnt;
195+
lmp_list.recvnum[s] = cnt;
208196
}
209-
210197
torch::Tensor sendproc_tensor =
211198
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
212199
torch::Tensor recvproc_tensor =
213200
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
214201

215202
torch::Tensor firstrecv_tensor =
216-
torch::from_blob(firstrecv_new.data(), {nswap}, int32_option).clone();
203+
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
217204
torch::Tensor recvnum_tensor =
218-
torch::from_blob(recvnum_new.data(), {nswap}, int32_option).clone();
205+
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
219206
torch::Tensor sendnum_tensor =
220-
torch::from_blob(sendnum_new.data(), {nswap}, int32_option).clone();
221-
222-
torch::Tensor sendlist_tensor =
223-
torch::from_blob(sendlist_new.data(),
224-
{static_cast<long>(sendlist_new.size())},
225-
int32_option)
226-
.clone();
227-
228-
// torch::Tensor firstrecv_tensor =
229-
// torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
230-
// torch::Tensor recvnum_tensor =
231-
// torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
232-
// torch::Tensor sendnum_tensor =
233-
// torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
207+
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
234208
torch::Tensor communicator_tensor;
235209
if (lmp_list.world == 0) {
236210
communicator_tensor = torch::empty({1}, torch::kInt64);
@@ -240,12 +214,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
240214
}
241215

242216
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
243-
// int total_send =
244-
// std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap,
245-
// 0);
246-
// torch::Tensor sendlist_tensor =
247-
// torch::from_blob(lmp_list.sendlist, {total_send},
248-
// int32_option);
217+
218+
int total_send =
219+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap,
220+
0);
221+
torch::Tensor sendlist_tensor =
222+
torch::from_blob(lmp_list.sendlist, {total_send},
223+
int32_option);
249224
comm_dict.insert_or_assign("send_list", sendlist_tensor);
250225
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
251226
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);

0 commit comments

Comments
 (0)