Skip to content

Commit 7e648f0

Browse files
iProzd, pre-commit-ci[bot], and wanghan-iapcm
authored
fix(c++): fix NULL type in custom op (#4889)
Replaces usage of the lmp_list send/recv arrays with new vectors that map indices using fwd_map and synchronize counts via MPI. Updates tensor construction to use these new vectors, improving correctness and flexibility in distributed communication.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Refactor**
  * Message-passing now remaps send lists to actual atoms before building communication tensors, improving consistency for distributed runs.
* **Bug Fixes**
  * Invalid or out-of-range send indices are filtered and counts updated to prevent communication mismatches and related errors.
* **Tests**
  * Added a test covering a type-map scenario with a NULL mapping to exercise the updated handling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
1 parent 51fbd76 commit 7e648f0

15 files changed

Lines changed: 411 additions & 76 deletions

source/api_cc/include/DeepPotPD.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,10 @@ class DeepPotPD : public DeepPotBackend {
392392
bool gpu_enabled;
393393
std::unique_ptr<paddle_infer::Tensor> firstneigh_tensor;
394394
std::unique_ptr<paddle_infer::Tensor> mapping_tensor;
395+
std::vector<std::vector<int>> remapped_sendlist;
396+
std::vector<int*> remapped_sendlist_ptrs;
397+
std::vector<int> remapped_sendnum;
398+
std::vector<int> remapped_recvnum;
395399
};
396400

397401
} // namespace deepmd

source/api_cc/include/DeepPotPT.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -340,6 +340,10 @@ class DeepPotPT : public DeepPotBackend {
340340
at::Tensor firstneigh_tensor;
341341
c10::optional<torch::Tensor> mapping_tensor;
342342
torch::Dict<std::string, torch::Tensor> comm_dict;
343+
std::vector<std::vector<int>> remapped_sendlist;
344+
std::vector<int*> remapped_sendlist_ptrs;
345+
std::vector<int> remapped_sendnum;
346+
std::vector<int> remapped_recvnum;
343347
bool profiler_enabled{false};
344348
std::string profiler_file;
345349
/**

source/api_cc/include/DeepSpinPT.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -262,6 +262,10 @@ class DeepSpinPT : public DeepSpinBackend {
262262
at::Tensor firstneigh_tensor;
263263
c10::optional<torch::Tensor> mapping_tensor;
264264
torch::Dict<std::string, torch::Tensor> comm_dict;
265+
std::vector<std::vector<int>> remapped_sendlist;
266+
std::vector<int*> remapped_sendlist_ptrs;
267+
std::vector<int> remapped_sendnum;
268+
std::vector<int> remapped_recvnum;
265269
/**
266270
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
267271
* @param[in] f The function to run.

source/api_cc/include/common.h

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,27 @@ void select_map_inv(typename std::vector<VT>::iterator out,
153153
const std::vector<int>& fwd_map,
154154
const int& stride);
155155

156+
/**
157+
* @brief Remap communication sendlist for message passing with NULL-type atoms.
158+
*
159+
* When NULL-type (virtual) atoms are present, the original LAMMPS sendlist
160+
* contains indices referring to virtual atoms that have been filtered out by
161+
* select_real_atoms_coord. This function remaps those indices through fwd_map
162+
* and independently recomputes recvnum using firstrecv.
163+
*
164+
* @param[out] new_sendlist Remapped send lists per swap (vector of vectors).
165+
* @param[out] new_sendnum Number of atoms to send per swap after remapping.
166+
* @param[out] new_recvnum Number of atoms to receive per swap after remapping.
167+
* @param[in] lmp_list The LAMMPS neighbor list containing communication info.
168+
* @param[in] fwd_map Forward map from original atom index to real-atom index
169+
* (-1 for virtual/NULL atoms).
170+
*/
171+
void remap_comm_sendlist(std::vector<std::vector<int>>& new_sendlist,
172+
std::vector<int>& new_sendnum,
173+
std::vector<int>& new_recvnum,
174+
const InputNlist& lmp_list,
175+
const std::vector<int>& fwd_map);
176+
156177
/**
157178
* @brief Get the number of threads from the environment variable.
158179
* @details A warning will be thrown if environment variables are not set.

source/api_cc/include/commonPT.h

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,106 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
#pragma once
3+
4+
#ifdef BUILD_PYTORCH
5+
#include <torch/torch.h>
6+
7+
#include <cstdint>
8+
#include <vector>
9+
10+
#include "common.h"
11+
#include "neighbor_list.h"
12+
13+
namespace deepmd {
14+
15+
/**
16+
* @brief Build comm_dict tensors from sendlist/sendnum/recvnum buffers.
17+
*
18+
* This is the shared tensor-building logic for all PyTorch backends
19+
* (DeepPotPT, DeepSpinPT). Backend-specific entries (e.g. has_spin)
20+
* should be added by the caller after this function returns.
21+
*
22+
* @param[out] comm_dict The communication dictionary to populate.
23+
* @param[in] lmp_list The LAMMPS neighbor list (for sendproc/recvproc/world).
24+
* @param[in] sendlist Pointer array (int**) for each swap's send list.
25+
* @param[in] sendnum Number of send atoms per swap.
26+
* @param[in] recvnum Number of recv atoms per swap.
27+
*/
28+
inline void build_comm_dict(torch::Dict<std::string, torch::Tensor>& comm_dict,
29+
const InputNlist& lmp_list,
30+
int** sendlist,
31+
int* sendnum,
32+
int* recvnum) {
33+
int nswap = lmp_list.nswap;
34+
auto int32_option =
35+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32);
36+
auto int64_option =
37+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64);
38+
39+
torch::Tensor sendlist_tensor =
40+
torch::from_blob(static_cast<void*>(sendlist), {nswap}, int64_option);
41+
torch::Tensor sendnum_tensor =
42+
torch::from_blob(sendnum, {nswap}, int32_option);
43+
torch::Tensor recvnum_tensor =
44+
torch::from_blob(recvnum, {nswap}, int32_option);
45+
torch::Tensor sendproc_tensor =
46+
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
47+
torch::Tensor recvproc_tensor =
48+
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
49+
50+
torch::Tensor communicator_tensor;
51+
static std::int64_t null_communicator = 0;
52+
if (lmp_list.world == nullptr) {
53+
communicator_tensor =
54+
torch::from_blob(&null_communicator, {1}, torch::kInt64);
55+
} else {
56+
communicator_tensor =
57+
torch::from_blob(const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
58+
}
59+
60+
comm_dict.insert_or_assign("send_list", sendlist_tensor);
61+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
62+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
63+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
64+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
65+
comm_dict.insert_or_assign("communicator", communicator_tensor);
66+
}
67+
68+
/**
69+
* @brief Build comm_dict with sendlist remapping for virtual (NULL-type) atoms.
70+
*
71+
* Calls remap_comm_sendlist() to remap indices through fwd_map, then
72+
* build_comm_dict() to create tensors. Backend-specific entries (e.g.
73+
* has_spin) should be added by the caller after this function returns.
74+
*
75+
* @param[out] comm_dict The communication dictionary to populate.
76+
* @param[in] lmp_list The LAMMPS neighbor list containing communication info.
77+
* @param[in] fwd_map Map from original atom index to real-atom index (-1 for
78+
* virtual atoms).
79+
* @param[out] remapped_sendlist Storage for remapped send lists (kept alive for
80+
* tensor lifetime).
81+
* @param[out] remapped_sendlist_ptrs Pointer array into remapped_sendlist.
82+
* @param[out] remapped_sendnum Remapped send counts per swap.
83+
* @param[out] remapped_recvnum Remapped recv counts per swap.
84+
*/
85+
inline void build_comm_dict_with_virtual_atoms(
86+
torch::Dict<std::string, torch::Tensor>& comm_dict,
87+
const InputNlist& lmp_list,
88+
const std::vector<int>& fwd_map,
89+
std::vector<std::vector<int>>& remapped_sendlist,
90+
std::vector<int*>& remapped_sendlist_ptrs,
91+
std::vector<int>& remapped_sendnum,
92+
std::vector<int>& remapped_recvnum) {
93+
remap_comm_sendlist(remapped_sendlist, remapped_sendnum, remapped_recvnum,
94+
lmp_list, fwd_map);
95+
int nswap = lmp_list.nswap;
96+
remapped_sendlist_ptrs.resize(nswap);
97+
for (int s = 0; s < nswap; ++s) {
98+
remapped_sendlist_ptrs[s] = remapped_sendlist[s].data();
99+
}
100+
build_comm_dict(comm_dict, lmp_list, remapped_sendlist_ptrs.data(),
101+
remapped_sendnum.data(), remapped_recvnum.data());
102+
}
103+
104+
} // namespace deepmd
105+
106+
#endif // BUILD_PYTORCH

source/api_cc/src/DeepPotPD.cc

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,8 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
378378
bkw_map, nall_real, nloc_real, coord, atype, aparam,
379379
nghost, ntypes, 1, daparam, nall, aparam_nall);
380380
int nloc = nall_real - nghost_real;
381+
// Detect whether any NULL-type atoms were filtered out.
382+
bool has_null_atoms = (nall_real < nall);
381383
int nframes = 1;
382384
std::vector<VALUETYPE> coord_wrapped = dcoord;
383385
auto coord_wrapped_Tensor = predictor_fl->GetInputHandle("coord");
@@ -391,6 +393,28 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
391393
nlist_data.shuffle_exclude_empty(fwd_map);
392394
nlist_data.padding();
393395
if (do_message_passing) {
396+
// Determine the actual sendlist/sendnum/recvnum to use.
397+
// When NULL-type atoms exist, remap sendlist indices through fwd_map.
398+
int** eff_sendlist;
399+
int* eff_sendnum;
400+
int* eff_recvnum;
401+
if (has_null_atoms) {
402+
remap_comm_sendlist(remapped_sendlist, remapped_sendnum,
403+
remapped_recvnum, lmp_list, fwd_map);
404+
int nswap = lmp_list.nswap;
405+
remapped_sendlist_ptrs.resize(nswap);
406+
for (int s = 0; s < nswap; ++s) {
407+
remapped_sendlist_ptrs[s] = remapped_sendlist[s].data();
408+
}
409+
eff_sendlist = remapped_sendlist_ptrs.data();
410+
eff_sendnum = remapped_sendnum.data();
411+
eff_recvnum = remapped_recvnum.data();
412+
} else {
413+
eff_sendlist = lmp_list.sendlist;
414+
eff_sendnum = lmp_list.sendnum;
415+
eff_recvnum = lmp_list.recvnum;
416+
}
417+
394418
auto sendproc_tensor = predictor_fl->GetInputHandle("send_proc");
395419
auto recvproc_tensor = predictor_fl->GetInputHandle("recv_proc");
396420
auto recvnum_tensor = predictor_fl->GetInputHandle("recv_num");
@@ -406,26 +430,18 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
406430
recvproc_tensor->CopyFromCpu(lmp_list.recvproc);
407431

408432
recvnum_tensor->Reshape({nswap});
409-
recvnum_tensor->CopyFromCpu(lmp_list.recvnum);
433+
recvnum_tensor->CopyFromCpu(eff_recvnum);
410434

411435
sendnum_tensor->Reshape({nswap});
412-
if (sizeof(lmp_list.sendnum[0]) != sizeof(int32_t)) {
413-
std::vector<int32_t> temp_data(nswap);
414-
for (int i = 0; i < nswap; i++) {
415-
temp_data[i] = static_cast<int32_t>(lmp_list.sendnum[i]);
416-
}
417-
sendnum_tensor->CopyFromCpu(temp_data.data());
418-
} else {
419-
sendnum_tensor->CopyFromCpu(lmp_list.sendnum);
420-
}
436+
sendnum_tensor->CopyFromCpu(eff_sendnum);
437+
421438
communicator_tensor->Reshape({1});
422439
if (lmp_list.world) {
423440
communicator_tensor->CopyFromCpu(static_cast<int*>(lmp_list.world));
424441
}
425442

426443
assert(sizeof(std::intptr_t) == 8);
427-
int total_send =
428-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
444+
int total_send = std::accumulate(eff_sendnum, eff_sendnum + nswap, 0);
429445
sendlist_tensor->Reshape({total_send});
430446

431447
/**
@@ -437,7 +453,7 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
437453
pointer_addresses.reserve(nswap);
438454
for (int iswap = 0; iswap < nswap; ++iswap) {
439455
std::intptr_t addr =
440-
reinterpret_cast<std::intptr_t>(lmp_list.sendlist[iswap]);
456+
reinterpret_cast<std::intptr_t>(eff_sendlist[iswap]);
441457
pointer_addresses.push_back(addr);
442458
}
443459
sendlist_tensor->CopyFromCpu(pointer_addresses.data());

source/api_cc/src/DeepPotPT.cc

Lines changed: 10 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99

1010
#include "common.h"
11+
#include "commonPT.h"
1112
#include "device.h"
1213
#include "errors.h"
1314

@@ -198,6 +199,8 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
198199
bkw_map, nall_real, nloc_real, coord, atype, aparam,
199200
nghost, ntypes, 1, daparam, nall, aparam_nall);
200201
int nloc = nall_real - nghost_real;
202+
// Detect whether any NULL-type atoms were filtered out.
203+
bool has_null_atoms = (nall_real < nall);
201204
int nframes = 1;
202205
std::vector<VALUETYPE> coord_wrapped = dcoord;
203206
at::Tensor coord_wrapped_Tensor =
@@ -211,36 +214,14 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
211214
nlist_data.shuffle_exclude_empty(fwd_map);
212215
nlist_data.padding();
213216
if (do_message_passing) {
214-
int nswap = lmp_list.nswap;
215-
torch::Tensor sendproc_tensor =
216-
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
217-
torch::Tensor recvproc_tensor =
218-
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
219-
torch::Tensor firstrecv_tensor =
220-
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
221-
torch::Tensor recvnum_tensor =
222-
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
223-
torch::Tensor sendnum_tensor =
224-
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
225-
torch::Tensor communicator_tensor;
226-
if (lmp_list.world == 0) {
227-
communicator_tensor = torch::empty({1}, torch::kInt64);
217+
if (has_null_atoms) {
218+
build_comm_dict_with_virtual_atoms(
219+
comm_dict, lmp_list, fwd_map, remapped_sendlist,
220+
remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum);
228221
} else {
229-
communicator_tensor = torch::from_blob(
230-
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
222+
build_comm_dict(comm_dict, lmp_list, lmp_list.sendlist,
223+
lmp_list.sendnum, lmp_list.recvnum);
231224
}
232-
233-
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
234-
int total_send =
235-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
236-
torch::Tensor sendlist_tensor =
237-
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
238-
comm_dict.insert_or_assign("send_list", sendlist_tensor);
239-
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
240-
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
241-
comm_dict.insert_or_assign("send_num", sendnum_tensor);
242-
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
243-
comm_dict.insert_or_assign("communicator", communicator_tensor);
244225
}
245226
if (lmp_list.mapping) {
246227
std::vector<std::int64_t> mapping(nall_real);
@@ -272,7 +253,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
272253
options)
273254
.to(device);
274255
}
275-
c10::Dict<c10::IValue, c10::IValue> outputs =
256+
auto outputs =
276257
(do_message_passing)
277258
? module
278259
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,

0 commit comments

Comments
 (0)