Skip to content

Commit 3068a24

Browse files
committed
Revert commit
This reverts commit b415c67.
1 parent b415c67 commit 3068a24

8 files changed

Lines changed: 33 additions & 87 deletions

File tree

source/api_cc/include/DeepPotPD.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,10 +392,6 @@ class DeepPotPD : public DeepPotBackend {
392392
bool gpu_enabled;
393393
std::unique_ptr<paddle_infer::Tensor> firstneigh_tensor;
394394
std::unique_ptr<paddle_infer::Tensor> mapping_tensor;
395-
// Owned comm data for remapped sendlist (NULL type filtering)
396-
std::vector<std::vector<int>> comm_sendlist_;
397-
std::vector<int> comm_sendnum_;
398-
std::vector<int> comm_recvnum_;
399395
};
400396

401397
} // namespace deepmd

source/api_cc/include/DeepPotPT.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,6 @@ class DeepPotPT : public DeepPotBackend {
340340
at::Tensor firstneigh_tensor;
341341
c10::optional<torch::Tensor> mapping_tensor;
342342
torch::Dict<std::string, torch::Tensor> comm_dict;
343-
// Owned comm data for remapped sendlist (NULL type filtering)
344-
std::vector<std::vector<int>> comm_sendlist_;
345-
std::vector<int> comm_sendnum_;
346-
std::vector<int> comm_recvnum_;
347-
std::vector<int*> comm_sendlist_ptrs_;
348343
bool profiler_enabled{false};
349344
std::string profiler_file;
350345
/**

source/api_cc/include/DeepSpinPT.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,6 @@ class DeepSpinPT : public DeepSpinBackend {
262262
at::Tensor firstneigh_tensor;
263263
c10::optional<torch::Tensor> mapping_tensor;
264264
torch::Dict<std::string, torch::Tensor> comm_dict;
265-
// Owned comm data for remapped sendlist (NULL type filtering)
266-
std::vector<std::vector<int>> comm_sendlist_;
267-
std::vector<int> comm_sendnum_;
268-
std::vector<int> comm_recvnum_;
269-
std::vector<int*> comm_sendlist_ptrs_;
270265
/**
271266
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
272267
* @param[in] f The function to run.

source/api_cc/include/common.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,6 @@ void select_map_inv(typename std::vector<VT>::iterator out,
153153
const std::vector<int>& fwd_map,
154154
const int& stride);
155155

156-
/**
157-
* @brief Build filtered/remapped sendlist for NULL type filtering.
158-
* @param[out] new_sendlist Remapped sendlist vectors (one per swap).
159-
* @param[out] new_sendnum Number of real atoms to send per swap.
160-
* @param[out] new_recvnum Number of real atoms to receive per swap.
161-
* @param[in] inlist The input neighbor list with comm data.
162-
* @param[in] fwd_map Forward map from original to remapped indices.
163-
*/
164-
void select_real_atoms_sendlist(std::vector<std::vector<int>>& new_sendlist,
165-
std::vector<int>& new_sendnum,
166-
std::vector<int>& new_recvnum,
167-
const InputNlist& inlist,
168-
const std::vector<int>& fwd_map);
169-
170156
/**
171157
* @brief Get the number of threads from the environment variable.
172158
* @details A warning will be thrown if environment variables are not set.

source/api_cc/src/DeepPotPD.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -399,29 +399,33 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
399399
auto sendlist_tensor = predictor_fl->GetInputHandle("send_list");
400400

401401
int nswap = lmp_list.nswap;
402-
select_real_atoms_sendlist(comm_sendlist_, comm_sendnum_, comm_recvnum_,
403-
lmp_list, fwd_map);
404402
sendproc_tensor->Reshape({nswap});
405403
sendproc_tensor->CopyFromCpu(lmp_list.sendproc);
406404

407405
recvproc_tensor->Reshape({nswap});
408406
recvproc_tensor->CopyFromCpu(lmp_list.recvproc);
409407

410408
recvnum_tensor->Reshape({nswap});
411-
recvnum_tensor->CopyFromCpu(comm_recvnum_.data());
409+
recvnum_tensor->CopyFromCpu(lmp_list.recvnum);
412410

413411
sendnum_tensor->Reshape({nswap});
414-
sendnum_tensor->CopyFromCpu(comm_sendnum_.data());
412+
if (sizeof(lmp_list.sendnum[0]) != sizeof(int32_t)) {
413+
std::vector<int32_t> temp_data(nswap);
414+
for (int i = 0; i < nswap; i++) {
415+
temp_data[i] = static_cast<int32_t>(lmp_list.sendnum[i]);
416+
}
417+
sendnum_tensor->CopyFromCpu(temp_data.data());
418+
} else {
419+
sendnum_tensor->CopyFromCpu(lmp_list.sendnum);
420+
}
415421
communicator_tensor->Reshape({1});
416422
if (lmp_list.world) {
417423
communicator_tensor->CopyFromCpu(static_cast<int*>(lmp_list.world));
418424
}
419425

420426
assert(sizeof(std::intptr_t) == 8);
421-
int total_send = 0;
422-
for (int s = 0; s < nswap; ++s) {
423-
total_send += comm_sendnum_[s];
424-
}
427+
int total_send =
428+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
425429
sendlist_tensor->Reshape({total_send});
426430

427431
/**
@@ -433,7 +437,7 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
433437
pointer_addresses.reserve(nswap);
434438
for (int iswap = 0; iswap < nswap; ++iswap) {
435439
std::intptr_t addr =
436-
reinterpret_cast<std::intptr_t>(comm_sendlist_[iswap].data());
440+
reinterpret_cast<std::intptr_t>(lmp_list.sendlist[iswap]);
437441
pointer_addresses.push_back(addr);
438442
}
439443
sendlist_tensor->CopyFromCpu(pointer_addresses.data());

source/api_cc/src/DeepPotPT.cc

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,16 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
212212
nlist_data.padding();
213213
if (do_message_passing) {
214214
int nswap = lmp_list.nswap;
215-
select_real_atoms_sendlist(comm_sendlist_, comm_sendnum_, comm_recvnum_,
216-
lmp_list, fwd_map);
217-
comm_sendlist_ptrs_.resize(nswap);
218-
for (int s = 0; s < nswap; ++s) {
219-
comm_sendlist_ptrs_[s] = comm_sendlist_[s].data();
220-
}
221215
torch::Tensor sendproc_tensor =
222216
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
223217
torch::Tensor recvproc_tensor =
224218
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
225-
torch::Tensor sendnum_tensor =
226-
torch::from_blob(comm_sendnum_.data(), {nswap}, int32_option);
219+
torch::Tensor firstrecv_tensor =
220+
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
227221
torch::Tensor recvnum_tensor =
228-
torch::from_blob(comm_recvnum_.data(), {nswap}, int32_option);
222+
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
223+
torch::Tensor sendnum_tensor =
224+
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
229225
torch::Tensor communicator_tensor;
230226
if (lmp_list.world == 0) {
231227
communicator_tensor = torch::empty({1}, torch::kInt64);
@@ -234,9 +230,11 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
234230
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
235231
}
236232

237-
int ptr_len = nswap * static_cast<int>(sizeof(int*) / sizeof(int32_t));
238-
torch::Tensor sendlist_tensor = torch::from_blob(
239-
comm_sendlist_ptrs_.data(), {std::max(ptr_len, 1)}, int32_option);
233+
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
234+
int total_send =
235+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
236+
torch::Tensor sendlist_tensor =
237+
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
240238
comm_dict.insert_or_assign("send_list", sendlist_tensor);
241239
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
242240
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,30 +182,28 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
182182
nlist_data.padding();
183183
if (do_message_passing) {
184184
int nswap = lmp_list.nswap;
185-
select_real_atoms_sendlist(comm_sendlist_, comm_sendnum_, comm_recvnum_,
186-
lmp_list, fwd_map);
187-
comm_sendlist_ptrs_.resize(nswap);
188-
for (int s = 0; s < nswap; ++s) {
189-
comm_sendlist_ptrs_[s] = comm_sendlist_[s].data();
190-
}
191185
torch::Tensor sendproc_tensor =
192186
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
193187
torch::Tensor recvproc_tensor =
194188
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
195-
torch::Tensor sendnum_tensor =
196-
torch::from_blob(comm_sendnum_.data(), {nswap}, int32_option);
189+
torch::Tensor firstrecv_tensor =
190+
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
197191
torch::Tensor recvnum_tensor =
198-
torch::from_blob(comm_recvnum_.data(), {nswap}, int32_option);
192+
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
193+
torch::Tensor sendnum_tensor =
194+
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
199195
torch::Tensor communicator_tensor;
200196
if (lmp_list.world == 0) {
201197
communicator_tensor = torch::empty({1}, torch::kInt64);
202198
} else {
203199
communicator_tensor = torch::from_blob(
204200
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
205201
}
206-
int ptr_len = nswap * static_cast<int>(sizeof(int*) / sizeof(int32_t));
207-
torch::Tensor sendlist_tensor = torch::from_blob(
208-
comm_sendlist_ptrs_.data(), {std::max(ptr_len, 1)}, int32_option);
202+
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
203+
int total_send =
204+
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
205+
torch::Tensor sendlist_tensor =
206+
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
209207
torch::Tensor has_spin = torch::tensor({1}, int32_option);
210208
comm_dict.insert_or_assign("send_list", sendlist_tensor);
211209
comm_dict.insert_or_assign("send_proc", sendproc_tensor);

source/api_cc/src/common.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -232,32 +232,6 @@ template void deepmd::select_real_atoms_coord<float>(
232232
const int& nall,
233233
const bool aparam_nall);
234234

235-
void deepmd::select_real_atoms_sendlist(
236-
std::vector<std::vector<int>>& new_sendlist,
237-
std::vector<int>& new_sendnum,
238-
std::vector<int>& new_recvnum,
239-
const InputNlist& inlist,
240-
const std::vector<int>& fwd_map) {
241-
int nswap = inlist.nswap;
242-
new_sendlist.resize(nswap);
243-
new_sendnum.resize(nswap);
244-
new_recvnum.resize(nswap);
245-
for (int s = 0; s < nswap; ++s) {
246-
new_sendlist[s].clear();
247-
int orig_num = inlist.sendnum[s];
248-
new_sendlist[s].reserve(orig_num);
249-
for (int i = 0; i < orig_num; ++i) {
250-
int idx = inlist.sendlist[s][i];
251-
int mapped = fwd_map[idx];
252-
if (mapped >= 0) {
253-
new_sendlist[s].push_back(mapped);
254-
}
255-
}
256-
new_sendnum[s] = new_sendlist[s].size();
257-
new_recvnum[s] = new_sendlist[s].size();
258-
}
259-
}
260-
261235
void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
262236
const int natoms) {
263237
int inum = natoms >= 0 ? natoms : inlist.inum;

0 commit comments

Comments
 (0)