77#include < cstdint>
88
99#include " common.h"
10+ #include " commonPT.h"
1011#include " device.h"
1112#include " errors.h"
1213
@@ -163,6 +164,8 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
163164 bkw_map, nall_real, nloc_real, coord, atype, aparam,
164165 nghost, ntypes, 1 , daparam, nall, aparam_nall);
165166 int nloc = nall_real - nghost_real;
167+ // Detect whether any NULL-type atoms were filtered out.
168+ bool has_null_atoms = (nall_real < nall);
166169 int nframes = 1 ;
167170 std::vector<VALUETYPE> coord_wrapped = dcoord;
168171 at::Tensor coord_wrapped_Tensor =
@@ -175,43 +178,32 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
175178 std::vector<std::int64_t > atype_64 (datype.begin (), datype.end ());
176179 at::Tensor atype_Tensor =
177180 torch::from_blob (atype_64.data (), {1 , nall_real}, int_option).to (device);
178- c10::optional<torch::Tensor> mapping_tensor;
179181 if (ago == 0 ) {
180182 nlist_data.copy_from_nlist (lmp_list, nall - nghost);
181183 nlist_data.shuffle_exclude_empty (fwd_map);
182184 nlist_data.padding ();
183185 if (do_message_passing) {
184- int nswap = lmp_list.nswap ;
185- torch::Tensor sendproc_tensor =
186- torch::from_blob (lmp_list.sendproc , {nswap}, int32_option);
187- torch::Tensor recvproc_tensor =
188- torch::from_blob (lmp_list.recvproc , {nswap}, int32_option);
189- torch::Tensor firstrecv_tensor =
190- torch::from_blob (lmp_list.firstrecv , {nswap}, int32_option);
191- torch::Tensor recvnum_tensor =
192- torch::from_blob (lmp_list.recvnum , {nswap}, int32_option);
193- torch::Tensor sendnum_tensor =
194- torch::from_blob (lmp_list.sendnum , {nswap}, int32_option);
195- torch::Tensor communicator_tensor;
196- if (lmp_list.world == 0 ) {
197- communicator_tensor = torch::empty ({1 }, torch::kInt64 );
186+ if (has_null_atoms) {
187+ build_comm_dict_with_virtual_atoms (
188+ comm_dict, lmp_list, fwd_map, remapped_sendlist,
189+ remapped_sendlist_ptrs, remapped_sendnum, remapped_recvnum);
198190 } else {
199- communicator_tensor = torch::from_blob (
200- const_cast < void *>( lmp_list.world ), { 1 }, torch:: kInt64 );
191+ build_comm_dict (comm_dict, lmp_list, lmp_list. sendlist ,
192+ lmp_list.sendnum , lmp_list. recvnum );
201193 }
202- torch::Tensor nswap_tensor = torch::tensor (nswap, int32_option);
203- int total_send =
204- std::accumulate (lmp_list. sendnum , lmp_list. sendnum + nswap, 0 );
205- torch::Tensor sendlist_tensor =
206- torch::from_blob (lmp_list. sendlist , {total_send}, int32_option);
207- torch::Tensor has_spin = torch::tensor ({ 1 }, int32_option);
208- comm_dict. insert_or_assign ( " send_list " , sendlist_tensor );
209- comm_dict. insert_or_assign ( " send_proc " , sendproc_tensor);
210- comm_dict. insert_or_assign ( " recv_proc " , recvproc_tensor) ;
211- comm_dict. insert_or_assign ( " send_num " , sendnum_tensor);
212- comm_dict. insert_or_assign ( " recv_num " , recvnum_tensor);
213- comm_dict. insert_or_assign ( " communicator " , communicator_tensor);
214- comm_dict. insert_or_assign ( " has_spin " , has_spin );
194+ // DeepSpin-specific: signal spin model to the Python side
195+ auto int32_option =
196+ torch::TensorOptions (). device (torch:: kCPU ). dtype (torch:: kInt32 );
197+ comm_dict. insert_or_assign ( " has_spin " , torch::tensor ({ 1 }, int32_option));
198+ }
199+ if (lmp_list. mapping ) {
200+ std::vector<std:: int64_t > mapping (nall_real );
201+ for ( size_t ii = 0 ; ii < nall_real; ii++) {
202+ mapping[ii] = lmp_list. mapping [fwd_map[ii]] ;
203+ }
204+ mapping_tensor =
205+ torch::from_blob (mapping. data (), { 1 , nall_real}, int_option)
206+ . to (device );
215207 }
216208 }
217209 at::Tensor firstneigh = createNlistTensor2 (nlist_data.jlist );
@@ -234,7 +226,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
234226 options)
235227 .to (device);
236228 }
237- c10::Dict<c10::IValue, c10::IValue> outputs =
229+ auto outputs =
238230 (do_message_passing)
239231 ? module
240232 .run_method (" forward_lower" , coord_wrapped_Tensor, atype_Tensor,
0 commit comments