Skip to content

Commit ba8f52e

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f99ad4d commit ba8f52e

1 file changed

Lines changed: 56 additions & 53 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -182,18 +182,17 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
182182
if (do_message_passing) {
183183
int nswap = lmp_list.nswap;
184184

185-
std::vector<int> sendnum_new(nswap, 0);
186-
std::vector<int> sendlist_new;
185+
std::vector<int> sendnum_new(nswap, 0);
186+
std::vector<int> sendlist_new;
187187
sendlist_new.reserve(
188-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0)
189-
);
188+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0));
190189
for (int s = 0; s < nswap; ++s) {
191190
int cnt = 0;
192191
for (int k = 0; k < lmp_list.sendnum[s]; ++k) {
193192
const int old_idx = lmp_list.sendlist[s][k];
194193
int mapped = (old_idx >= 0 && old_idx < (int)fwd_map.size())
195-
? fwd_map[old_idx]
196-
: -1;
194+
? fwd_map[old_idx]
195+
: -1;
197196
if (mapped >= 0) {
198197
sendlist_new.push_back(mapped);
199198
++cnt;
@@ -202,58 +201,60 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
202201
sendnum_new[s] = cnt;
203202
}
204203

205-
std::vector<int> recvnum_new(nswap, 0);
206-
#ifdef MPI_FOUND
207-
if (lmp_list.world) {
208-
MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
209-
const int TAG_BASE = 0x7a31;
210-
for (int s = 0; s < nswap; ++s) {
211-
const int send_to = lmp_list.sendproc[s];
212-
const int recv_from = lmp_list.recvproc[s];
213-
int send_cnt = sendnum_new[s];
214-
int recv_cnt = 0;
215-
MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s,
216-
&recv_cnt, 1, MPI_INT, recv_from, TAG_BASE + s,
217-
comm, MPI_STATUS_IGNORE);
218-
recvnum_new[s] = recv_cnt;
219-
}
220-
} else
221-
#endif
222-
{
223-
for (int s = 0; s < nswap; ++s) recvnum_new[s] = sendnum_new[s];
224-
}
204+
std::vector<int> recvnum_new(nswap, 0);
205+
#ifdef MPI_FOUND
206+
if (lmp_list.world) {
207+
MPI_Comm comm = *static_cast<MPI_Comm*>(lmp_list.world);
208+
const int TAG_BASE = 0x7a31;
209+
for (int s = 0; s < nswap; ++s) {
210+
const int send_to = lmp_list.sendproc[s];
211+
const int recv_from = lmp_list.recvproc[s];
212+
int send_cnt = sendnum_new[s];
213+
int recv_cnt = 0;
214+
MPI_Sendrecv(&send_cnt, 1, MPI_INT, send_to, TAG_BASE + s, &recv_cnt,
215+
1, MPI_INT, recv_from, TAG_BASE + s, comm,
216+
MPI_STATUS_IGNORE);
217+
recvnum_new[s] = recv_cnt;
218+
}
219+
} else
220+
#endif
221+
{
222+
for (int s = 0; s < nswap; ++s) {
223+
recvnum_new[s] = sendnum_new[s];
224+
}
225+
}
225226

226-
std::vector<int> firstrecv_new(nswap, 0);
227-
int acc = 0;
228-
for (int s = 0; s < nswap; ++s) {
229-
firstrecv_new[s] = acc;
230-
acc += recvnum_new[s];
231-
}
227+
std::vector<int> firstrecv_new(nswap, 0);
228+
int acc = 0;
229+
for (int s = 0; s < nswap; ++s) {
230+
firstrecv_new[s] = acc;
231+
acc += recvnum_new[s];
232+
}
232233

233234
torch::Tensor sendproc_tensor =
234235
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
235236
torch::Tensor recvproc_tensor =
236237
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
237238

238-
torch::Tensor firstrecv_tensor =
239-
torch::from_blob(firstrecv_new.data(), {nswap}, int32_option).clone();
240-
torch::Tensor recvnum_tensor =
241-
torch::from_blob(recvnum_new.data(), {nswap}, int32_option).clone();
242-
torch::Tensor sendnum_tensor =
243-
torch::from_blob(sendnum_new.data(), {nswap}, int32_option).clone();
244-
245-
torch::Tensor sendlist_tensor =
246-
torch::from_blob(sendlist_new.data(),
247-
{ static_cast<long>(sendlist_new.size()) },
248-
int32_option).clone();
239+
torch::Tensor firstrecv_tensor =
240+
torch::from_blob(firstrecv_new.data(), {nswap}, int32_option).clone();
241+
torch::Tensor recvnum_tensor =
242+
torch::from_blob(recvnum_new.data(), {nswap}, int32_option).clone();
243+
torch::Tensor sendnum_tensor =
244+
torch::from_blob(sendnum_new.data(), {nswap}, int32_option).clone();
249245

246+
torch::Tensor sendlist_tensor =
247+
torch::from_blob(sendlist_new.data(),
248+
{static_cast<long>(sendlist_new.size())},
249+
int32_option)
250+
.clone();
250251

251-
// torch::Tensor firstrecv_tensor =
252-
// torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
253-
// torch::Tensor recvnum_tensor =
254-
// torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
255-
// torch::Tensor sendnum_tensor =
256-
// torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
252+
// torch::Tensor firstrecv_tensor =
253+
// torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
254+
// torch::Tensor recvnum_tensor =
255+
// torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
256+
// torch::Tensor sendnum_tensor =
257+
// torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
257258
torch::Tensor communicator_tensor;
258259
if (lmp_list.world == 0) {
259260
communicator_tensor = torch::empty({1}, torch::kInt64);
@@ -263,10 +264,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
263264
}
264265

265266
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
266-
// int total_send =
267-
// std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
268-
// torch::Tensor sendlist_tensor =
269-
// torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
267+
// int total_send =
268+
// std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap,
269+
// 0);
270+
// torch::Tensor sendlist_tensor =
271+
// torch::from_blob(lmp_list.sendlist, {total_send},
272+
// int32_option);
270273
comm_dict.insert_or_assign("send_list", sendlist_tensor);
271274
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
272275
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);

0 commit comments

Comments
 (0)