Skip to content

Commit 2ec8dd0

Browse files
Copilot and njzjz committed
feat(pt): implement comprehensive neighbor list support in DeepTensorPT with proper inheritance
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 85dc0f6 commit 2ec8dd0

2 files changed

Lines changed: 169 additions & 6 deletions

File tree

source/api_cc/include/DeepTensorPT.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ class DeepTensorPT : public DeepTensorBase {
246246
int gpu_id;
247247
bool gpu_enabled;
248248
NeighborListData nlist_data;
249+
// Neighbor list tensors for efficient computation
250+
at::Tensor firstneigh_tensor;
249251

250252
/**
251253
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.

source/api_cc/src/DeepTensorPT.cc

Lines changed: 167 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@
1313

1414
using namespace deepmd;
1515

16+
/**
 * @brief Pack a rectangular neighbor list into a dense int32 tensor.
 *
 * The rows of @p data are concatenated in order into one contiguous buffer
 * and exposed as a tensor viewed with shape {1, nloc, nnei}, where nloc is
 * the number of rows and nnei is the common row length.
 *
 * @param[in] data Per-atom neighbor index rows. Every row must have the same
 *                 length (callers are presumably expected to pad the list
 *                 first -- TODO confirm against NeighborListData::padding).
 * @return An int32 tensor of shape {1, nloc, nnei}; {1, 0, 0} when @p data is
 *         empty.
 * @throws c10::Error (via TORCH_CHECK) if the rows differ in length.
 *
 * NOTE(review): this helper has external linkage at file scope. If another
 * translation unit linked into the same library defines a function with the
 * same name and signature, the link will fail -- verify, and consider moving
 * it into an anonymous namespace.
 */
torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
  const std::int64_t nloc = static_cast<std::int64_t>(data.size());
  // Take the width from the first row instead of deriving it as
  // total_size / nloc: the division hides ragged input whenever the total
  // happens to be divisible by the row count.
  const size_t row_len = data.empty() ? 0 : data.front().size();

  std::vector<int> flat_data;
  flat_data.reserve(data.size() * row_len);
  for (const auto& row : data) {
    // A ragged list cannot be represented as a dense {1, nloc, nnei} tensor;
    // fail loudly here rather than reshaping misaligned data.
    TORCH_CHECK(row.size() == row_len,
                "createNlistTensor: neighbor list rows must all have the same "
                "length (expected ",
                row_len, ", got ", row.size(), ")");
    flat_data.insert(flat_data.end(), row.begin(), row.end());
  }

  torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
  // Tensor dimensions are 64-bit; avoid narrowing the sizes through int.
  return flat_tensor.view({1, nloc, static_cast<std::int64_t>(row_len)});
}
32+
1633
void DeepTensorPT::translate_error(std::function<void()> f) {
1734
try {
1835
f();
@@ -434,13 +451,157 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
434451
const std::vector<int>& atype,
435452
const std::vector<VALUETYPE>& box,
436453
const int nghost,
437-
const InputNlist& inlist,
454+
const InputNlist& lmp_list,
438455
const bool request_deriv) {
439-
// Implement neighbor list support following DeepPotPT pattern
440-
// For now, use the simple compute_inner approach
441-
// TODO: Add full neighbor list optimization for better performance
442-
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
443-
atype, box, request_deriv);
456+
torch::Device device(torch::kCUDA, gpu_id);
457+
if (!gpu_enabled) {
458+
device = torch::Device(torch::kCPU);
459+
}
460+
461+
int natoms = atype.size();
462+
auto options = torch::TensorOptions().dtype(torch::kFloat64);
463+
torch::ScalarType floatType = torch::kFloat64;
464+
if (std::is_same<VALUETYPE, float>::value) {
465+
options = torch::TensorOptions().dtype(torch::kFloat32);
466+
floatType = torch::kFloat32;
467+
}
468+
auto int32_option =
469+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32);
470+
auto int_option =
471+
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64);
472+
473+
// Select real atoms following DeepPotPT pattern
474+
std::vector<VALUETYPE> dcoord, aparam_;
475+
std::vector<int> datype, fwd_map, bkw_map;
476+
int nghost_real, nall_real, nloc_real;
477+
int nall = natoms;
478+
int nframes = 1;
479+
std::vector<VALUETYPE> aparam; // Empty for tensor models
480+
select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map,
481+
bkw_map, nall_real, nloc_real, coord, atype, aparam,
482+
nghost, ntypes, nframes, 0, nall, false);
483+
int nloc = nall_real - nghost_real;
484+
485+
std::vector<VALUETYPE> coord_wrapped = dcoord;
486+
at::Tensor coord_wrapped_Tensor =
487+
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
488+
.to(device);
489+
std::vector<std::int64_t> atype_64(datype.begin(), datype.end());
490+
at::Tensor atype_Tensor =
491+
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
492+
493+
// Process neighbor list following DeepPotPT pattern
494+
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
495+
nlist_data.shuffle_exclude_empty(fwd_map);
496+
nlist_data.padding();
497+
498+
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
499+
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
500+
501+
// Prepare box tensor
502+
std::vector<VALUETYPE> box_wrapped = box;
503+
at::Tensor box_tensor =
504+
torch::from_blob(box_wrapped.data(), {1, 9}, options).to(device);
505+
506+
// Create input vector for model
507+
std::vector<torch::jit::IValue> inputs;
508+
inputs.push_back(coord_wrapped_Tensor);
509+
inputs.push_back(atype_Tensor);
510+
inputs.push_back(firstneigh_tensor);
511+
inputs.push_back(box_tensor);
512+
513+
bool do_atom_virial_tensor = request_deriv;
514+
inputs.push_back(do_atom_virial_tensor);
515+
516+
// Forward pass through model
517+
c10::Dict<c10::IValue, c10::IValue> outputs =
518+
module.forward(inputs).toGenericDict();
519+
520+
// Process global tensor
521+
if (outputs.contains("global_tensor") || outputs.contains("dipole") ||
522+
outputs.contains("global_dipole")) {
523+
c10::IValue tensor_out;
524+
if (outputs.contains("global_tensor")) {
525+
tensor_out = outputs.at("global_tensor");
526+
} else if (outputs.contains("global_dipole")) {
527+
tensor_out = outputs.at("global_dipole");
528+
} else {
529+
tensor_out = outputs.at("dipole");
530+
}
531+
532+
torch::Tensor flat_tensor = tensor_out.toTensor().view({-1}).to(floatType);
533+
torch::Tensor cpu_tensor = flat_tensor.to(torch::kCPU);
534+
global_tensor.assign(cpu_tensor.data_ptr<VALUETYPE>(),
535+
cpu_tensor.data_ptr<VALUETYPE>() + cpu_tensor.numel());
536+
}
537+
538+
// Process force if available
539+
if (outputs.contains("force") || outputs.contains("extended_force")) {
540+
c10::IValue force_out = outputs.contains("extended_force")
541+
? outputs.at("extended_force")
542+
: outputs.at("force");
543+
torch::Tensor flat_force = force_out.toTensor().view({-1}).to(floatType);
544+
torch::Tensor cpu_force = flat_force.to(torch::kCPU);
545+
std::vector<VALUETYPE> dforce;
546+
dforce.assign(cpu_force.data_ptr<VALUETYPE>(),
547+
cpu_force.data_ptr<VALUETYPE>() + cpu_force.numel());
548+
549+
// Map back to original atom order using select_map
550+
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim * 3);
551+
select_map<VALUETYPE>(force, dforce, bkw_map, odim * 3, nframes,
552+
fwd_map.size(), nall_real);
553+
}
554+
555+
// Process virial if available
556+
if (outputs.contains("virial")) {
557+
c10::IValue virial_out = outputs.at("virial");
558+
torch::Tensor flat_virial = virial_out.toTensor().view({-1}).to(floatType);
559+
torch::Tensor cpu_virial = flat_virial.to(torch::kCPU);
560+
virial.assign(cpu_virial.data_ptr<VALUETYPE>(),
561+
cpu_virial.data_ptr<VALUETYPE>() + cpu_virial.numel());
562+
}
563+
564+
// Process atom tensor if available
565+
if (outputs.contains("atom_tensor")) {
566+
c10::IValue atom_tensor_out = outputs.at("atom_tensor");
567+
torch::Tensor flat_atom_tensor =
568+
atom_tensor_out.toTensor().view({-1}).to(floatType);
569+
torch::Tensor cpu_atom_tensor = flat_atom_tensor.to(torch::kCPU);
570+
std::vector<VALUETYPE> datom_tensor_tmp;
571+
datom_tensor_tmp.assign(
572+
cpu_atom_tensor.data_ptr<VALUETYPE>(),
573+
cpu_atom_tensor.data_ptr<VALUETYPE>() + cpu_atom_tensor.numel());
574+
575+
// Map back to original atom order using select_map
576+
atom_tensor.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim);
577+
select_map<VALUETYPE>(atom_tensor, datom_tensor_tmp, bkw_map, odim, nframes,
578+
fwd_map.size(), nall_real);
579+
}
580+
581+
// Process atomic virial if requested and available
582+
if (request_deriv && (outputs.contains("atom_virial") ||
583+
outputs.contains("extended_virial"))) {
584+
c10::IValue atom_virial_out = outputs.contains("extended_virial")
585+
? outputs.at("extended_virial")
586+
: outputs.at("atom_virial");
587+
torch::Tensor flat_atom_virial =
588+
atom_virial_out.toTensor().view({-1}).to(floatType);
589+
torch::Tensor cpu_atom_virial = flat_atom_virial.to(torch::kCPU);
590+
std::vector<VALUETYPE> datom_virial_tmp;
591+
datom_virial_tmp.assign(
592+
cpu_atom_virial.data_ptr<VALUETYPE>(),
593+
cpu_atom_virial.data_ptr<VALUETYPE>() + cpu_atom_virial.numel());
594+
595+
// Map back to original atom order using select_map
596+
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim *
597+
9);
598+
select_map<VALUETYPE>(atom_virial, datom_virial_tmp, bkw_map, odim * 9,
599+
nframes, fwd_map.size(), nall_real);
600+
} else if (request_deriv) {
601+
// Fill with zeros if atomic virial not available but requested
602+
atom_virial.assign(static_cast<size_t>(natoms) * odim * 9,
603+
static_cast<VALUETYPE>(0.0));
604+
}
444605
}
445606

446607
// Public wrapper functions

0 commit comments

Comments
 (0)