@@ -199,15 +199,15 @@ void DeepTensorPT::get_type_map(std::string& type_map) {
199199}
200200
201201template <typename VALUETYPE >
202- void DeepTensorPT::compute_inner (std::vector<VALUETYPE >& global_tensor,
203- std::vector<VALUETYPE >& force,
204- std::vector<VALUETYPE >& virial,
205- std::vector<VALUETYPE >& atom_tensor,
206- std::vector<VALUETYPE >& atom_virial,
207- const std::vector<VALUETYPE >& coord,
208- const std::vector<int >& atype,
209- const std::vector<VALUETYPE >& box,
210- const bool request_deriv) {
202+ void DeepTensorPT::compute (std::vector<VALUETYPE >& global_tensor,
203+ std::vector<VALUETYPE >& force,
204+ std::vector<VALUETYPE >& virial,
205+ std::vector<VALUETYPE >& atom_tensor,
206+ std::vector<VALUETYPE >& atom_virial,
207+ const std::vector<VALUETYPE >& coord,
208+ const std::vector<int >& atype,
209+ const std::vector<VALUETYPE >& box,
210+ const bool request_deriv) {
211211 // This is the simpler version without neighbor list optimization
212212 // Use a dummy neighbor list and call the full version
213213 deepmd::InputNlist dummy_nlist;
@@ -219,22 +219,22 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
219219 dummy_nlist.firstneigh .resize (dummy_nlist.inum );
220220
221221 // Call the neighbor list version with nghost=0 and empty neighbor list
222- compute_inner (global_tensor, force, virial, atom_tensor, atom_virial, coord,
223- atype, box, 0 , dummy_nlist, request_deriv);
222+ compute (global_tensor, force, virial, atom_tensor, atom_virial, coord, atype ,
223+ box, 0 , dummy_nlist, request_deriv);
224224}
225225
226226template <typename VALUETYPE >
227- void DeepTensorPT::compute_inner (std::vector<VALUETYPE >& global_tensor,
228- std::vector<VALUETYPE >& force,
229- std::vector<VALUETYPE >& virial,
230- std::vector<VALUETYPE >& atom_tensor,
231- std::vector<VALUETYPE >& atom_virial,
232- const std::vector<VALUETYPE >& coord,
233- const std::vector<int >& atype,
234- const std::vector<VALUETYPE >& box,
235- const int nghost,
236- const InputNlist& lmp_list,
237- const bool request_deriv) {
227+ void DeepTensorPT::compute (std::vector<VALUETYPE >& global_tensor,
228+ std::vector<VALUETYPE >& force,
229+ std::vector<VALUETYPE >& virial,
230+ std::vector<VALUETYPE >& atom_tensor,
231+ std::vector<VALUETYPE >& atom_virial,
232+ const std::vector<VALUETYPE >& coord,
233+ const std::vector<int >& atype,
234+ const std::vector<VALUETYPE >& box,
235+ const int nghost,
236+ const InputNlist& lmp_list,
237+ const bool request_deriv) {
238238 torch::Device device (torch::kCUDA , gpu_id);
239239 if (!gpu_enabled) {
240240 device = torch::Device (torch::kCPU );
@@ -397,8 +397,8 @@ void DeepTensorPT::computew(std::vector<double>& global_tensor,
397397 const std::vector<double >& box,
398398 const bool request_deriv) {
399399 translate_error ([&] {
400- compute_inner (global_tensor, force, virial, atom_tensor, atom_virial, coord,
401- atype, box, request_deriv);
400+ compute (global_tensor, force, virial, atom_tensor, atom_virial, coord,
401+ atype, box, request_deriv);
402402 });
403403}
404404
@@ -412,8 +412,8 @@ void DeepTensorPT::computew(std::vector<float>& global_tensor,
412412 const std::vector<float >& box,
413413 const bool request_deriv) {
414414 translate_error ([&] {
415- compute_inner (global_tensor, force, virial, atom_tensor, atom_virial, coord,
416- atype, box, request_deriv);
415+ compute (global_tensor, force, virial, atom_tensor, atom_virial, coord,
416+ atype, box, request_deriv);
417417 });
418418}
419419
@@ -429,8 +429,8 @@ void DeepTensorPT::computew(std::vector<double>& global_tensor,
429429 const InputNlist& inlist,
430430 const bool request_deriv) {
431431 translate_error ([&] {
432- compute_inner (global_tensor, force, virial, atom_tensor, atom_virial, coord,
433- atype, box, nghost, inlist, request_deriv);
432+ compute (global_tensor, force, virial, atom_tensor, atom_virial, coord,
433+ atype, box, nghost, inlist, request_deriv);
434434 });
435435}
436436
@@ -446,8 +446,8 @@ void DeepTensorPT::computew(std::vector<float>& global_tensor,
446446 const InputNlist& inlist,
447447 const bool request_deriv) {
448448 translate_error ([&] {
449- compute_inner (global_tensor, force, virial, atom_tensor, atom_virial, coord,
450- atype, box, nghost, inlist, request_deriv);
449+ compute (global_tensor, force, virial, atom_tensor, atom_virial, coord,
450+ atype, box, nghost, inlist, request_deriv);
451451 });
452452}
453453
0 commit comments