Skip to content

Commit 5875158

Browse files
committed
update fix, ref to deepmodeling#5268
1 parent c8e4be1 commit 5875158

12 files changed

Lines changed: 225 additions & 65 deletions

File tree

source/api_cc/include/DeepPotPT.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,10 @@ class DeepPotPT : public DeepPotBackend {
340340
at::Tensor firstneigh_tensor;
341341
c10::optional<torch::Tensor> mapping_tensor;
342342
torch::Dict<std::string, torch::Tensor> comm_dict;
343+
std::vector<std::vector<int>> remapped_sendlist;
344+
std::vector<int*> remapped_sendlist_ptrs;
345+
std::vector<int> remapped_sendnum;
346+
std::vector<int> remapped_recvnum;
343347
bool profiler_enabled{false};
344348
std::string profiler_file;
345349
/**

source/api_cc/include/DeepSpinPT.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ class DeepSpinPT : public DeepSpinBackend {
262262
at::Tensor firstneigh_tensor;
263263
c10::optional<torch::Tensor> mapping_tensor;
264264
torch::Dict<std::string, torch::Tensor> comm_dict;
265+
std::vector<std::vector<int>> remapped_sendlist;
266+
std::vector<int*> remapped_sendlist_ptrs;
267+
std::vector<int> remapped_sendnum;
268+
std::vector<int> remapped_recvnum;
265269
/**
266270
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
267271
* @param[in] f The function to run.

source/api_cc/include/common.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,27 @@ void select_map_inv(typename std::vector<VT>::iterator out,
153153
const std::vector<int>& fwd_map,
154154
const int& stride);
155155

156+
/**
157+
* @brief Remap communication sendlist for message passing with NULL-type atoms.
158+
*
159+
* When NULL-type (virtual) atoms are present, the original LAMMPS sendlist
160+
* contains indices referring to virtual atoms that have been filtered out by
161+
* select_real_atoms_coord. This function remaps those indices through fwd_map
162+
* and independently recomputes recvnum using firstrecv.
163+
*
164+
* @param[out] new_sendlist Remapped send lists per swap (vector of vectors).
165+
* @param[out] new_sendnum Number of atoms to send per swap after remapping.
166+
* @param[out] new_recvnum Number of atoms to receive per swap after remapping.
167+
* @param[in] lmp_list The LAMMPS neighbor list containing communication info.
168+
* @param[in] fwd_map Forward map from original atom index to real-atom index
169+
* (-1 for virtual/NULL atoms).
170+
*/
171+
void remap_comm_sendlist(std::vector<std::vector<int>>& new_sendlist,
172+
std::vector<int>& new_sendnum,
173+
std::vector<int>& new_recvnum,
174+
const InputNlist& lmp_list,
175+
const std::vector<int>& fwd_map);
176+
156177
/**
157178
* @brief Get the number of threads from the environment variable.
158179
* @details A warning will be thrown if environment variables are not set.

source/api_cc/include/commonPT.h

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
#pragma once
3+
4+
#ifdef BUILD_PYTORCH
5+
#include <torch/torch.h>
6+
7+
#include <cstdint>
8+
#include <vector>
9+
10+
#include "common.h"
11+
#include "neighbor_list.h"
12+
13+
namespace deepmd {
14+
15+
/**
16+
* @brief Build comm_dict tensors from sendlist/sendnum/recvnum buffers.
17+
*
18+
* This is the shared tensor-building logic for all PyTorch backends
19+
* (DeepPotPT, DeepSpinPT). Backend-specific entries (e.g. has_spin)
20+
* should be added by the caller after this function returns.
21+
*
22+
* @param[out] comm_dict The communication dictionary to populate.
23+
* @param[in] lmp_list The LAMMPS neighbor list (for sendproc/recvproc/world).
24+
* @param[in] sendlist Pointer array (int**) for each swap's send list.
25+
* @param[in] sendnum Number of send atoms per swap.
26+
* @param[in] recvnum Number of recv atoms per swap.
27+
*/
28+
inline void build_comm_dict(torch::Dict<std::string, torch::Tensor>& comm_dict,
29+
const InputNlist& lmp_list,
30+
int** sendlist,
31+
int* sendnum,
32+
int* recvnum) {
33+
int nswap = lmp_list.nswap;
34+
auto int32_option =
35+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32);
36+
auto int64_option =
37+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64);
38+
39+
torch::Tensor sendlist_tensor =
40+
torch::from_blob(static_cast<void*>(sendlist), {nswap}, int64_option);
41+
torch::Tensor sendnum_tensor =
42+
torch::from_blob(sendnum, {nswap}, int32_option);
43+
torch::Tensor recvnum_tensor =
44+
torch::from_blob(recvnum, {nswap}, int32_option);
45+
torch::Tensor sendproc_tensor =
46+
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
47+
torch::Tensor recvproc_tensor =
48+
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
49+
50+
torch::Tensor communicator_tensor;
51+
static std::int64_t null_communicator = 0;
52+
if (lmp_list.world == nullptr) {
53+
communicator_tensor =
54+
torch::from_blob(&null_communicator, {1}, torch::kInt64);
55+
} else {
56+
communicator_tensor =
57+
torch::from_blob(const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
58+
}
59+
60+
comm_dict.insert_or_assign("send_list", sendlist_tensor);
61+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
62+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
63+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
64+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
65+
comm_dict.insert_or_assign("communicator", communicator_tensor);
66+
}
67+
68+
/**
69+
* @brief Build comm_dict with sendlist remapping for virtual (NULL-type) atoms.
70+
*
71+
* Calls remap_comm_sendlist() to remap indices through fwd_map, then
72+
* build_comm_dict() to create tensors. Backend-specific entries (e.g.
73+
* has_spin) should be added by the caller after this function returns.
74+
*
75+
* @param[out] comm_dict The communication dictionary to populate.
76+
* @param[in] lmp_list The LAMMPS neighbor list containing communication info.
77+
* @param[in] fwd_map Map from original atom index to real-atom index (-1 for
78+
* virtual atoms).
79+
* @param[out] remapped_sendlist Storage for remapped send lists (kept alive for
80+
* tensor lifetime).
81+
* @param[out] remapped_sendlist_ptrs Pointer array into remapped_sendlist.
82+
* @param[out] remapped_sendnum Remapped send counts per swap.
83+
* @param[out] remapped_recvnum Remapped recv counts per swap.
84+
*/
85+
inline void build_comm_dict_with_virtual_atoms(
86+
torch::Dict<std::string, torch::Tensor>& comm_dict,
87+
const InputNlist& lmp_list,
88+
const std::vector<int>& fwd_map,
89+
std::vector<std::vector<int>>& remapped_sendlist,
90+
std::vector<int*>& remapped_sendlist_ptrs,
91+
std::vector<int>& remapped_sendnum,
92+
std::vector<int>& remapped_recvnum) {
93+
remap_comm_sendlist(remapped_sendlist, remapped_sendnum, remapped_recvnum,
94+
lmp_list, fwd_map);
95+
int nswap = lmp_list.nswap;
96+
remapped_sendlist_ptrs.resize(nswap);
97+
for (int s = 0; s < nswap; ++s) {
98+
remapped_sendlist_ptrs[s] = remapped_sendlist[s].data();
99+
}
100+
build_comm_dict(comm_dict, lmp_list, remapped_sendlist_ptrs.data(),
101+
remapped_sendnum.data(), remapped_recvnum.data());
102+
}
103+
104+
} // namespace deepmd
105+
106+
#endif // BUILD_PYTORCH

source/api_cc/src/DeepPotPT.cc

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99

1010
#include "common.h"
11+
#include "commonPT.h"
1112
#include "device.h"
1213
#include "errors.h"
1314

@@ -198,6 +199,8 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
198199
bkw_map, nall_real, nloc_real, coord, atype, aparam,
199200
nghost, ntypes, 1, daparam, nall, aparam_nall);
200201
int nloc = nall_real - nghost_real;
202+
// Detect whether any NULL-type atoms were filtered out.
203+
bool has_null_atoms = (nall_real < nall);
201204
int nframes = 1;
202205
std::vector<VALUETYPE> coord_wrapped = dcoord;
203206
at::Tensor coord_wrapped_Tensor =
@@ -211,36 +214,14 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
211214
nlist_data.shuffle_exclude_empty(fwd_map);
212215
nlist_data.padding();
213216
if (do_message_passing) {
214-
int nswap = lmp_list.nswap;
215-
torch::Tensor sendproc_tensor =
216-
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
217-
torch::Tensor recvproc_tensor =
218-
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
219-
torch::Tensor firstrecv_tensor =
220-
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
221-
torch::Tensor recvnum_tensor =
222-
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
223-
torch::Tensor sendnum_tensor =
224-
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
225-
torch::Tensor communicator_tensor;
226-
if (lmp_list.world == 0) {
227-
communicator_tensor = torch::empty({1}, torch::kInt64);
217+
if (has_null_atoms) {
218+
build_comm_dict_with_virtual_atoms(
219+
comm_dict, lmp_list, fwd_map, remapped_sendlist,
220+
remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum);
228221
} else {
229-
communicator_tensor = torch::from_blob(
230-
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
222+
build_comm_dict(comm_dict, lmp_list, lmp_list.sendlist,
223+
lmp_list.sendnum, lmp_list.recvnum);
231224
}
232-
233-
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
234-
int total_send =
235-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
236-
torch::Tensor sendlist_tensor =
237-
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
238-
comm_dict.insert_or_assign("send_list", sendlist_tensor);
239-
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
240-
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
241-
comm_dict.insert_or_assign("send_num", sendnum_tensor);
242-
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
243-
comm_dict.insert_or_assign("communicator", communicator_tensor);
244225
}
245226
if (lmp_list.mapping) {
246227
std::vector<std::int64_t> mapping(nall_real);
@@ -272,7 +253,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
272253
options)
273254
.to(device);
274255
}
275-
c10::Dict<c10::IValue, c10::IValue> outputs =
256+
auto outputs =
276257
(do_message_passing)
277258
? module
278259
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,

source/api_cc/src/DeepSpinPT.cc

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cstdint>
88

99
#include "common.h"
10+
#include "commonPT.h"
1011
#include "device.h"
1112
#include "errors.h"
1213

@@ -163,6 +164,8 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
163164
bkw_map, nall_real, nloc_real, coord, atype, aparam,
164165
nghost, ntypes, 1, daparam, nall, aparam_nall);
165166
int nloc = nall_real - nghost_real;
167+
// Detect whether any NULL-type atoms were filtered out.
168+
bool has_null_atoms = (nall_real < nall);
166169
int nframes = 1;
167170
std::vector<VALUETYPE> coord_wrapped = dcoord;
168171
at::Tensor coord_wrapped_Tensor =
@@ -175,43 +178,32 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
175178
std::vector<std::int64_t> atype_64(datype.begin(), datype.end());
176179
at::Tensor atype_Tensor =
177180
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
178-
c10::optional<torch::Tensor> mapping_tensor;
179181
if (ago == 0) {
180182
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
181183
nlist_data.shuffle_exclude_empty(fwd_map);
182184
nlist_data.padding();
183185
if (do_message_passing) {
184-
int nswap = lmp_list.nswap;
185-
torch::Tensor sendproc_tensor =
186-
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
187-
torch::Tensor recvproc_tensor =
188-
torch::from_blob(lmp_list.recvproc, {nswap}, int32_option);
189-
torch::Tensor firstrecv_tensor =
190-
torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option);
191-
torch::Tensor recvnum_tensor =
192-
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
193-
torch::Tensor sendnum_tensor =
194-
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
195-
torch::Tensor communicator_tensor;
196-
if (lmp_list.world == 0) {
197-
communicator_tensor = torch::empty({1}, torch::kInt64);
186+
if (has_null_atoms) {
187+
build_comm_dict_with_virtual_atoms(
188+
comm_dict, lmp_list, fwd_map, remapped_sendlist,
189+
remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum);
198190
} else {
199-
communicator_tensor = torch::from_blob(
200-
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
191+
build_comm_dict(comm_dict, lmp_list, lmp_list.sendlist,
192+
lmp_list.sendnum, lmp_list.recvnum);
201193
}
202-
torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
203-
int total_send =
204-
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
205-
torch::Tensor sendlist_tensor =
206-
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
207-
torch::Tensor has_spin = torch::tensor({1}, int32_option);
208-
comm_dict.insert_or_assign("send_list", sendlist_tensor);
209-
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
210-
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
211-
comm_dict.insert_or_assign("send_num", sendnum_tensor);
212-
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
213-
comm_dict.insert_or_assign("communicator", communicator_tensor);
214-
comm_dict.insert_or_assign("has_spin", has_spin);
194+
// DeepSpin-specific: signal to the Python side that this model has spin
195+
auto int32_option =
196+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32);
197+
comm_dict.insert_or_assign("has_spin", torch::tensor({1}, int32_option));
198+
}
199+
if (lmp_list.mapping) {
200+
std::vector<std::int64_t> mapping(nall_real);
201+
for (size_t ii = 0; ii < nall_real; ii++) {
202+
mapping[ii] = lmp_list.mapping[fwd_map[ii]];
203+
}
204+
mapping_tensor =
205+
torch::from_blob(mapping.data(), {1, nall_real}, int_option)
206+
.to(device);
215207
}
216208
}
217209
at::Tensor firstneigh = createNlistTensor2(nlist_data.jlist);
@@ -234,7 +226,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
234226
options)
235227
.to(device);
236228
}
237-
c10::Dict<c10::IValue, c10::IValue> outputs =
229+
auto outputs =
238230
(do_message_passing)
239231
? module
240232
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,

source/api_cc/src/common.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,45 @@ template void deepmd::select_real_atoms_coord<float>(
232232
const int& nall,
233233
const bool aparam_nall);
234234

235+
void deepmd::remap_comm_sendlist(std::vector<std::vector<int>>& new_sendlist,
236+
std::vector<int>& new_sendnum,
237+
std::vector<int>& new_recvnum,
238+
const InputNlist& lmp_list,
239+
const std::vector<int>& fwd_map) {
240+
int nswap = lmp_list.nswap;
241+
new_sendlist.resize(nswap);
242+
new_sendnum.resize(nswap);
243+
new_recvnum.resize(nswap);
244+
245+
for (int s = 0; s < nswap; ++s) {
246+
int orig_sendnum = lmp_list.sendnum[s];
247+
new_sendlist[s].clear();
248+
new_sendlist[s].reserve(orig_sendnum);
249+
250+
for (int k = 0; k < orig_sendnum; ++k) {
251+
int orig_idx = lmp_list.sendlist[s][k];
252+
int real_idx = fwd_map[orig_idx];
253+
if (real_idx >= 0) {
254+
new_sendlist[s].push_back(real_idx);
255+
}
256+
}
257+
new_sendnum[s] = static_cast<int>(new_sendlist[s].size());
258+
259+
// Independently compute recvnum using firstrecv range.
260+
// In MPI, sendnum and recvnum are independent per process.
261+
int firstrecv = lmp_list.firstrecv[s];
262+
int orig_recvnum = lmp_list.recvnum[s];
263+
int recv_count = 0;
264+
for (int k = 0; k < orig_recvnum; ++k) {
265+
int orig_idx = firstrecv + k;
266+
if (fwd_map[orig_idx] >= 0) {
267+
++recv_count;
268+
}
269+
}
270+
new_recvnum[s] = recv_count;
271+
}
272+
}
273+
235274
void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
236275
const int natoms) {
237276
int inum = natoms >= 0 ? natoms : inlist.inum;

source/lmp/fix_dplr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ FixDPLR::FixDPLR(LAMMPS* lmp, int narg, char** arg)
185185
}
186186
}
187187
if (!found_element && "NULL" == type_name) {
188-
type_idx_map.push_back(type_map.size()); // ghost type
188+
type_idx_map.push_back(-1); // virtual atom
189189
found_element = true;
190190
}
191191
if (!found_element) {

source/lmp/pair_deepmd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ void PairDeepMD::coeff(int narg, char** arg) {
853853
}
854854
}
855855
if (!found_element && "NULL" == type_name) {
856-
type_idx_map.push_back(type_map.size()); // ghost type
856+
type_idx_map.push_back(-1); // virtual atom
857857
found_element = true;
858858
}
859859
if (!found_element) {

source/lmp/pair_deepspin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ void PairDeepSpin::coeff(int narg, char** arg) {
883883
}
884884
}
885885
if (!found_element && "NULL" == type_name) {
886-
type_idx_map.push_back(type_map.size()); // ghost type
886+
type_idx_map.push_back(-1); // virtual atom
887887
found_element = true;
888888
}
889889
if (!found_element) {

0 commit comments

Comments
 (0)