Skip to content

Commit c71f369

Browse files
fix bugs in deeppotpd.cc
1 parent 07118cf commit c71f369

1 file changed

Lines changed: 36 additions & 32 deletions

File tree

source/api_cc/src/DeepPotPD.cc

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,23 @@ inline void enableTimestamp(bool enable = true) {
164164
}
165165
} // namespace logg
166166

167-
std::vector<int> createNlistTensorPD(
168-
const std::vector<std::vector<int>>& data) {
169-
std::vector<int> ret;
167+
void fillNlistTensor(const std::vector<std::vector<int>>& data,
168+
std::unique_ptr<paddle_infer::Tensor> flat_tensor) {
169+
size_t total_size = 0;
170170
for (const auto& row : data) {
171-
ret.insert(ret.end(), row.begin(), row.end());
171+
total_size += row.size();
172+
}
173+
std::vector<int> flat_data;
174+
flat_data.reserve(total_size);
175+
for (const auto& row : data) {
176+
flat_data.insert(flat_data.end(), row.begin(), row.end());
172177
}
173-
return ret;
174-
}
175178

179+
int nloc = data.size();
180+
int nnei = nloc > 0 ? total_size / nloc : 0;
181+
flat_tensor->Reshape({1, nloc, nnei});
182+
flat_tensor->CopyFromCpu(flag_data.data());
183+
}
176184
DeepPotPD::DeepPotPD() : inited(false) {}
177185
DeepPotPD::DeepPotPD(const std::string& model,
178186
const int& gpu_rank,
@@ -375,16 +383,15 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
375383
auto coord_wrapped_Tensor = predictor_fl->GetInputHandle("coord");
376384
coord_wrapped_Tensor->Reshape({1, nall_real, 3});
377385
coord_wrapped_Tensor->CopyFromCpu(coord_wrapped.data());
378-
386+
std::vector<std::int64_t> atype_64(datype.begin(), datype.end());
379387
auto atype_Tensor = predictor_fl->GetInputHandle("atype");
380388
atype_Tensor->Reshape({1, nall_real});
381-
atype_Tensor->CopyFromCpu(datype.data());
382-
389+
atype_Tensor->CopyFromCpu(atype_64.data());
383390
if (ago == 0) {
384-
nlist_data.copy_from_nlist(lmp_list);
391+
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
385392
nlist_data.shuffle_exclude_empty(fwd_map);
386393
nlist_data.padding();
387-
if (do_message_passing == 1 && nghost > 0) {
394+
if (do_message_passing) {
388395
auto sendproc_tensor = predictor_fl->GetInputHandle("send_proc");
389396
auto recvproc_tensor = predictor_fl->GetInputHandle("recv_proc");
390397
auto recvnum_tensor = predictor_fl->GetInputHandle("recv_num");
@@ -412,6 +419,7 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
412419
} else {
413420
sendnum_tensor->CopyFromCpu(lmp_list.sendnum);
414421
}
422+
415423
communicator_tensor->Reshape({1});
416424
if (lmp_list.world) {
417425
communicator_tensor->CopyFromCpu(static_cast<int*>(lmp_list.world));
@@ -446,23 +454,21 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
446454
this->mapping_tensor->CopyFromCpu(mapping.data());
447455
}
448456
}
449-
std::vector<int> firstneigh = createNlistTensorPD(nlist_data.jlist);
450457
this->firstneigh_tensor = predictor_fl->GetInputHandle("nlist");
451-
this->firstneigh_tensor->Reshape(
452-
{1, nloc, (int)firstneigh.size() / (int)nloc});
453-
this->firstneigh_tensor->CopyFromCpu(firstneigh.data());
458+
fillNlistTensor(nlist_data.jlist, this->firstneigh_tensor);
454459
bool do_atom_virial_tensor = atomic;
460+
std::unique_ptr<paddle_infer::Tensor> fparam_tensor;
455461
if (!fparam.empty()) {
456-
std::unique_ptr<paddle_infer::Tensor> fparam_tensor;
457462
fparam_tensor = predictor_fl->GetInputHandle("fparam");
458-
fparam_tensor->Reshape({1, static_cast<int>(fparam.size())});
463+
fparam_tensor->Reshape({1, static_cast<std::int64_t>(fparam.size())});
459464
fparam_tensor->CopyFromCpu(fparam.data());
460465
}
466+
std::unique_ptr<paddle_infer::Tensor> aparam_tensor;
461467
if (!aparam_.empty()) {
462-
std::unique_ptr<paddle_infer::Tensor> aparam_tensor;
463468
aparam_tensor = predictor_fl->GetInputHandle("aparam");
464469
aparam_tensor->Reshape(
465-
{1, lmp_list.inum, static_cast<int>(aparam_.size()) / lmp_list.inum});
470+
{1, lmp_list.inum,
471+
static_cast<std::int64_t>(aparam_.size()) / lmp_list.inum});
466472
aparam_tensor->CopyFromCpu((aparam_.data()));
467473
}
468474

@@ -510,7 +516,7 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
510516
}
511517
}
512518
template void DeepPotPD::compute<double, std::vector<ENERGYTYPE>>(
513-
std::vector<ENERGYTYPE>& dener,
519+
std::vector<ENERGYTYPE>& ener,
514520
std::vector<double>& force,
515521
std::vector<double>& virial,
516522
std::vector<double>& atom_energy,
@@ -522,11 +528,10 @@ template void DeepPotPD::compute<double, std::vector<ENERGYTYPE>>(
522528
const InputNlist& lmp_list,
523529
const int& ago,
524530
const std::vector<double>& fparam,
525-
const std::vector<double>& aparam_,
531+
const std::vector<double>& aparam,
526532
const bool atomic);
527-
528533
template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
529-
std::vector<ENERGYTYPE>& dener,
534+
std::vector<ENERGYTYPE>& ener,
530535
std::vector<float>& force,
531536
std::vector<float>& virial,
532537
std::vector<float>& atom_energy,
@@ -538,9 +543,8 @@ template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
538543
const InputNlist& lmp_list,
539544
const int& ago,
540545
const std::vector<float>& fparam,
541-
const std::vector<float>& aparam_,
546+
const std::vector<float>& aparam,
542547
const bool atomic);
543-
544548
// ENERGYVTYPE: std::vector<ENERGYTYPE> or ENERGYTYPE
545549
template <typename VALUETYPE, typename ENERGYVTYPE>
546550
void DeepPotPD::compute(ENERGYVTYPE& ener,
@@ -575,15 +579,15 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
575579
}
576580
std::unique_ptr<paddle_infer::Tensor> fparam_tensor;
577581
if (!fparam.empty()) {
578-
fparam_tensor = predictor->GetInputHandle("box");
579-
fparam_tensor->Reshape({1, static_cast<int>(fparam.size())});
582+
fparam_tensor = predictor->GetInputHandle("fparam");
583+
fparam_tensor->Reshape({1, static_cast<std::int64_t>(fparam.size())});
580584
fparam_tensor->CopyFromCpu((fparam.data()));
581585
}
582586
std::unique_ptr<paddle_infer::Tensor> aparam_tensor;
583587
if (!aparam.empty()) {
584-
aparam_tensor = predictor->GetInputHandle("box");
588+
aparam_tensor = predictor->GetInputHandle("aparam");
585589
aparam_tensor->Reshape(
586-
{1, natoms, static_cast<int>(aparam.size()) / natoms});
590+
{1, natoms, static_cast<std::int64_t>(aparam.size()) / natoms});
587591
aparam_tensor->CopyFromCpu((aparam.data()));
588592
}
589593

@@ -628,11 +632,11 @@ void DeepPotPD::compute(ENERGYVTYPE& ener,
628632

629633
template void DeepPotPD::compute<double, std::vector<ENERGYTYPE>>(
630634
std::vector<ENERGYTYPE>& ener,
631-
std::vector<double>& dforce,
635+
std::vector<double>& force,
632636
std::vector<double>& virial,
633637
std::vector<double>& atom_energy,
634638
std::vector<double>& atom_virial,
635-
const std::vector<double>& dcoord,
639+
const std::vector<double>& coord,
636640
const std::vector<int>& atype,
637641
const std::vector<double>& box,
638642
const std::vector<double>& fparam,
@@ -645,7 +649,7 @@ template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
645649
std::vector<float>& virial,
646650
std::vector<float>& atom_energy,
647651
std::vector<float>& atom_virial,
648-
const std::vector<float>& dcoord,
652+
const std::vector<float>& coord,
649653
const std::vector<int>& atype,
650654
const std::vector<float>& box,
651655
const std::vector<float>& fparam,

0 commit comments

Comments
 (0)