Skip to content

Commit 8ea0839

Browse files
Copilotnjzjz
andcommitted
fix(pt): follow DeepPotPT pattern by implementing separate compute methods and renaming compute_inner to compute
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent a75c823 commit 8ea0839

2 files changed

Lines changed: 50 additions & 50 deletions

File tree

source/api_cc/include/DeepTensorPT.h

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ class DeepTensorPT : public DeepTensorBase {
5656
* tensor, including force and virial.
5757
**/
5858
template <typename VALUETYPE>
59-
void compute_inner(std::vector<VALUETYPE>& global_tensor,
60-
std::vector<VALUETYPE>& force,
61-
std::vector<VALUETYPE>& virial,
62-
std::vector<VALUETYPE>& atom_tensor,
63-
std::vector<VALUETYPE>& atom_virial,
64-
const std::vector<VALUETYPE>& coord,
65-
const std::vector<int>& atype,
66-
const std::vector<VALUETYPE>& box,
67-
const bool request_deriv);
59+
void compute(std::vector<VALUETYPE>& global_tensor,
60+
std::vector<VALUETYPE>& force,
61+
std::vector<VALUETYPE>& virial,
62+
std::vector<VALUETYPE>& atom_tensor,
63+
std::vector<VALUETYPE>& atom_virial,
64+
const std::vector<VALUETYPE>& coord,
65+
const std::vector<int>& atype,
66+
const std::vector<VALUETYPE>& box,
67+
const bool request_deriv);
6868
/**
6969
* @brief Evaluate the global tensor and component-wise force and virial.
7070
* @param[out] global_tensor The global tensor to evaluate.
@@ -86,17 +86,17 @@ class DeepTensorPT : public DeepTensorBase {
8686
* tensor, including force and virial.
8787
**/
8888
template <typename VALUETYPE>
89-
void compute_inner(std::vector<VALUETYPE>& global_tensor,
90-
std::vector<VALUETYPE>& force,
91-
std::vector<VALUETYPE>& virial,
92-
std::vector<VALUETYPE>& atom_tensor,
93-
std::vector<VALUETYPE>& atom_virial,
94-
const std::vector<VALUETYPE>& coord,
95-
const std::vector<int>& atype,
96-
const std::vector<VALUETYPE>& box,
97-
const int nghost,
98-
const InputNlist& inlist,
99-
const bool request_deriv);
89+
void compute(std::vector<VALUETYPE>& global_tensor,
90+
std::vector<VALUETYPE>& force,
91+
std::vector<VALUETYPE>& virial,
92+
std::vector<VALUETYPE>& atom_tensor,
93+
std::vector<VALUETYPE>& atom_virial,
94+
const std::vector<VALUETYPE>& coord,
95+
const std::vector<int>& atype,
96+
const std::vector<VALUETYPE>& box,
97+
const int nghost,
98+
const InputNlist& inlist,
99+
const bool request_deriv);
100100

101101
public:
102102
/**

source/api_cc/src/DeepTensorPT.cc

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,15 @@ void DeepTensorPT::get_type_map(std::string& type_map) {
199199
}
200200

201201
template <typename VALUETYPE>
202-
void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
203-
std::vector<VALUETYPE>& force,
204-
std::vector<VALUETYPE>& virial,
205-
std::vector<VALUETYPE>& atom_tensor,
206-
std::vector<VALUETYPE>& atom_virial,
207-
const std::vector<VALUETYPE>& coord,
208-
const std::vector<int>& atype,
209-
const std::vector<VALUETYPE>& box,
210-
const bool request_deriv) {
202+
void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
203+
std::vector<VALUETYPE>& force,
204+
std::vector<VALUETYPE>& virial,
205+
std::vector<VALUETYPE>& atom_tensor,
206+
std::vector<VALUETYPE>& atom_virial,
207+
const std::vector<VALUETYPE>& coord,
208+
const std::vector<int>& atype,
209+
const std::vector<VALUETYPE>& box,
210+
const bool request_deriv) {
211211
// This is the simpler version without neighbor list optimization
212212
// Use a dummy neighbor list and call the full version
213213
deepmd::InputNlist dummy_nlist;
@@ -219,22 +219,22 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
219219
dummy_nlist.firstneigh.resize(dummy_nlist.inum);
220220

221221
// Call the neighbor list version with nghost=0 and empty neighbor list
222-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
223-
atype, box, 0, dummy_nlist, request_deriv);
222+
compute(global_tensor, force, virial, atom_tensor, atom_virial, coord, atype,
223+
box, 0, dummy_nlist, request_deriv);
224224
}
225225

226226
template <typename VALUETYPE>
227-
void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
228-
std::vector<VALUETYPE>& force,
229-
std::vector<VALUETYPE>& virial,
230-
std::vector<VALUETYPE>& atom_tensor,
231-
std::vector<VALUETYPE>& atom_virial,
232-
const std::vector<VALUETYPE>& coord,
233-
const std::vector<int>& atype,
234-
const std::vector<VALUETYPE>& box,
235-
const int nghost,
236-
const InputNlist& lmp_list,
237-
const bool request_deriv) {
227+
void DeepTensorPT::compute(std::vector<VALUETYPE>& global_tensor,
228+
std::vector<VALUETYPE>& force,
229+
std::vector<VALUETYPE>& virial,
230+
std::vector<VALUETYPE>& atom_tensor,
231+
std::vector<VALUETYPE>& atom_virial,
232+
const std::vector<VALUETYPE>& coord,
233+
const std::vector<int>& atype,
234+
const std::vector<VALUETYPE>& box,
235+
const int nghost,
236+
const InputNlist& lmp_list,
237+
const bool request_deriv) {
238238
torch::Device device(torch::kCUDA, gpu_id);
239239
if (!gpu_enabled) {
240240
device = torch::Device(torch::kCPU);
@@ -397,8 +397,8 @@ void DeepTensorPT::computew(std::vector<double>& global_tensor,
397397
const std::vector<double>& box,
398398
const bool request_deriv) {
399399
translate_error([&] {
400-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
401-
atype, box, request_deriv);
400+
compute(global_tensor, force, virial, atom_tensor, atom_virial, coord,
401+
atype, box, request_deriv);
402402
});
403403
}
404404

@@ -412,8 +412,8 @@ void DeepTensorPT::computew(std::vector<float>& global_tensor,
412412
const std::vector<float>& box,
413413
const bool request_deriv) {
414414
translate_error([&] {
415-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
416-
atype, box, request_deriv);
415+
compute(global_tensor, force, virial, atom_tensor, atom_virial, coord,
416+
atype, box, request_deriv);
417417
});
418418
}
419419

@@ -429,8 +429,8 @@ void DeepTensorPT::computew(std::vector<double>& global_tensor,
429429
const InputNlist& inlist,
430430
const bool request_deriv) {
431431
translate_error([&] {
432-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
433-
atype, box, nghost, inlist, request_deriv);
432+
compute(global_tensor, force, virial, atom_tensor, atom_virial, coord,
433+
atype, box, nghost, inlist, request_deriv);
434434
});
435435
}
436436

@@ -446,8 +446,8 @@ void DeepTensorPT::computew(std::vector<float>& global_tensor,
446446
const InputNlist& inlist,
447447
const bool request_deriv) {
448448
translate_error([&] {
449-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
450-
atype, box, nghost, inlist, request_deriv);
449+
compute(global_tensor, force, virial, atom_tensor, atom_virial, coord,
450+
atype, box, nghost, inlist, request_deriv);
451451
});
452452
}
453453

0 commit comments

Comments
 (0)