Skip to content

Commit 8fdc524

Browse files
committed
Use insert_or_assign for comm_dict tensor assignments
Replaces comm_dict.insert with comm_dict.insert_or_assign for all tensor assignments in DeepPotPT.cc and DeepSpinPT.cc. This ensures that existing keys are updated rather than causing errors or duplications, improving robustness when keys may already exist.
1 parent 106c973 commit 8fdc524

2 files changed

Lines changed: 11 additions & 11 deletions

File tree

source/api_cc/src/DeepPotPT.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,11 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
198198
torch::Tensor sendlist_tensor =
199199
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
200200
comm_dict.insert_or_assign("send_list", sendlist_tensor);
201-
comm_dict.insert("send_proc", sendproc_tensor);
202-
comm_dict.insert("recv_proc", recvproc_tensor);
203-
comm_dict.insert("send_num", sendnum_tensor);
204-
comm_dict.insert("recv_num", recvnum_tensor);
205-
comm_dict.insert("communicator", communicator_tensor);
201+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
202+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
203+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
204+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
205+
comm_dict.insert_or_assign("communicator", communicator_tensor);
206206
}
207207
if (lmp_list.mapping) {
208208
std::vector<std::int64_t> mapping(nall_real);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,12 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
206206
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
207207
torch::Tensor has_spin = torch::tensor({1}, int32_option);
208208
comm_dict.insert_or_assign("send_list", sendlist_tensor);
209-
comm_dict.insert("send_proc", sendproc_tensor);
210-
comm_dict.insert("recv_proc", recvproc_tensor);
211-
comm_dict.insert("send_num", sendnum_tensor);
212-
comm_dict.insert("recv_num", recvnum_tensor);
213-
comm_dict.insert("communicator", communicator_tensor);
214-
comm_dict.insert("has_spin", has_spin);
209+
comm_dict.insert_or_assign("send_proc", sendproc_tensor);
210+
comm_dict.insert_or_assign("recv_proc", recvproc_tensor);
211+
comm_dict.insert_or_assign("send_num", sendnum_tensor);
212+
comm_dict.insert_or_assign("recv_num", recvnum_tensor);
213+
comm_dict.insert_or_assign("communicator", communicator_tensor);
214+
comm_dict.insert_or_assign("has_spin", has_spin);
215215
}
216216
}
217217
at::Tensor firstneigh = createNlistTensor2(nlist_data.jlist);

0 commit comments

Comments
 (0)