Skip to content

Commit 68f0d21

Browse files
authored
fix(cc): use insert_or_assign instead of insert (deepmodeling#4844)
dict.insert will not replace old value of the key value while lammps reallocate will keep old value in old address. This may cause dict always use the initial value of some keys. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved handling of communication settings to ensure that updates to existing entries are correctly applied, preventing potential issues with outdated or duplicate information during tensor communication. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 2b32af5 commit 68f0d21

2 files changed

Lines changed: 13 additions & 13 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
197197
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
198198
torch::Tensor sendlist_tensor =
199199
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
200-
comm_dict.insert("send_list", sendlist_tensor);
201-
comm_dict.insert("send_proc", sendproc_tensor);
202-
comm_dict.insert("recv_proc", recvproc_tensor);
203-
comm_dict.insert("send_num", sendnum_tensor);
204-
comm_dict.insert("recv_num", recvnum_tensor);
205-
comm_dict.insert("communicator", communicator_tensor);
200+
comm_dict.insert_or_assign("send_list", sendlist_tensor);
201+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
202+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
203+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
204+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
205+
comm_dict.insert_or_assign("communicator", communicator_tensor);
206206
}
207207
if (lmp_list.mapping) {
208208
std::vector<std::int64_t> mapping(nall_real);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,13 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
205205
torch::Tensor sendlist_tensor =
206206
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
207207
torch::Tensor has_spin = torch::tensor({1}, int32_option);
208-
comm_dict.insert("send_list", sendlist_tensor);
209-
comm_dict.insert("send_proc", sendproc_tensor);
210-
comm_dict.insert("recv_proc", recvproc_tensor);
211-
comm_dict.insert("send_num", sendnum_tensor);
212-
comm_dict.insert("recv_num", recvnum_tensor);
213-
comm_dict.insert("communicator", communicator_tensor);
214-
comm_dict.insert("has_spin", has_spin);
208+
comm_dict.insert_or_assign("send_list", sendlist_tensor);
209+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
210+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
211+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
212+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
213+
comm_dict.insert_or_assign("communicator", communicator_tensor);
214+
comm_dict.insert_or_assign("has_spin", has_spin);
215215
}
216216
}
217217
at::Tensor firstneigh = createNlistTensor2(nlist_data.jlist);

0 commit comments

Comments
 (0)