Skip to content

Commit 7313b88

Browse files
Copilotnjzjz
andcommitted
feat(pt): implement comprehensive neighbor list support in DeepTensorPT with proper inheritance
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 2ec8dd0 commit 7313b88

1 file changed

Lines changed: 14 additions & 233 deletions

File tree

source/api_cc/src/DeepTensorPT.cc

Lines changed: 14 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <torch/csrc/jit/runtime/jit_exception.h>
66

77
#include <cstdint>
8+
#include <numeric> // for std::iota
89
#include <sstream>
910

1011
#include "common.h"
@@ -206,239 +207,19 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor,
206207
const std::vector<int>& atype,
207208
const std::vector<VALUETYPE>& box,
208209
const bool request_deriv) {
209-
torch::Device device(torch::kCUDA, gpu_id);
210-
if (!gpu_enabled) {
211-
device = torch::Device(torch::kCPU);
212-
}
213-
214-
int natoms = atype.size();
215-
auto options = torch::TensorOptions().dtype(torch::kFloat64);
216-
if (std::is_same<VALUETYPE, float>::value) {
217-
options = torch::TensorOptions().dtype(torch::kFloat32);
218-
}
219-
auto int_option =
220-
torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64);
221-
222-
// Convert inputs to tensors
223-
std::vector<VALUETYPE> coord_wrapped = coord;
224-
at::Tensor coord_tensor =
225-
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
226-
.to(device);
227-
228-
std::vector<std::int64_t> atype_64(atype.begin(), atype.end());
229-
at::Tensor atype_tensor =
230-
torch::from_blob(atype_64.data(), {1, natoms}, int_option).to(device);
231-
232-
std::vector<VALUETYPE> box_wrapped = box;
233-
at::Tensor box_tensor =
234-
torch::from_blob(box_wrapped.data(), {1, 9}, options).to(device);
235-
236-
// Create input vector
237-
std::vector<torch::jit::IValue> inputs;
238-
inputs.push_back(coord_tensor);
239-
inputs.push_back(atype_tensor);
240-
inputs.push_back(box_tensor);
241-
242-
// Forward pass through model
243-
torch::jit::IValue result;
244-
if (request_deriv) {
245-
inputs.push_back(torch::tensor(true)); // do_atomic_virial
246-
result = module.forward(inputs);
247-
} else {
248-
result = module.forward(inputs);
249-
}
250-
251-
auto result_dict = result.toGenericDict();
252-
253-
// Extract results - try common key names
254-
torch::Tensor global_tensor_tensor, atom_tensor_tensor;
255-
256-
// Try different possible keys for global tensor
257-
if (result_dict.contains("global_tensor")) {
258-
global_tensor_tensor = result_dict.at("global_tensor").toTensor().cpu();
259-
} else if (result_dict.contains("tensor")) {
260-
global_tensor_tensor = result_dict.at("tensor").toTensor().cpu();
261-
} else if (result_dict.contains("global_dipole")) {
262-
global_tensor_tensor = result_dict.at("global_dipole").toTensor().cpu();
263-
} else if (result_dict.contains("dipole")) {
264-
// For models that only output atomic tensor, sum to get global
265-
auto dipole_tensor = result_dict.at("dipole").toTensor().cpu();
266-
global_tensor_tensor =
267-
torch::sum(dipole_tensor, 1, true); // Sum over atoms, keep dims
268-
} else {
269-
throw deepmd::deepmd_exception(
270-
"PyTorch tensor model output missing global tensor (expected "
271-
"'global_tensor', 'tensor', 'global_dipole', or 'dipole' key)");
272-
}
273-
274-
// Try different possible keys for atomic tensor
275-
if (result_dict.contains("atomic_tensor")) {
276-
atom_tensor_tensor = result_dict.at("atomic_tensor").toTensor().cpu();
277-
} else if (result_dict.contains("atom_tensor")) {
278-
atom_tensor_tensor = result_dict.at("atom_tensor").toTensor().cpu();
279-
} else if (result_dict.contains("dipole")) {
280-
atom_tensor_tensor = result_dict.at("dipole").toTensor().cpu();
281-
} else {
282-
throw deepmd::deepmd_exception(
283-
"PyTorch tensor model output missing atomic tensor (expected "
284-
"'atomic_tensor', 'atom_tensor', or 'dipole' key)");
285-
}
286-
287-
// Determine task dimension if not already known
288-
if (odim == -1) {
289-
if (global_tensor_tensor.dim() >= 2) {
290-
odim = global_tensor_tensor.size(-1);
291-
} else if (atom_tensor_tensor.dim() >= 3) {
292-
odim = atom_tensor_tensor.size(-1);
293-
} else {
294-
throw deepmd::deepmd_exception(
295-
"Unable to determine task dimension from model output");
296-
}
297-
}
298-
299-
// Copy global tensor - convert to desired type
300-
global_tensor.resize(odim);
301-
torch::Tensor global_tensor_converted;
302-
if (std::is_same<VALUETYPE, float>::value) {
303-
global_tensor_converted = global_tensor_tensor.to(torch::kFloat32);
304-
auto global_tensor_acc = global_tensor_converted.accessor<float, 2>();
305-
for (int i = 0; i < odim; ++i) {
306-
global_tensor[i] = global_tensor_acc[0][i];
307-
}
308-
} else {
309-
global_tensor_converted = global_tensor_tensor.to(torch::kFloat64);
310-
auto global_tensor_acc = global_tensor_converted.accessor<double, 2>();
311-
for (int i = 0; i < odim; ++i) {
312-
global_tensor[i] = global_tensor_acc[0][i];
313-
}
314-
}
315-
316-
// Copy atom tensor - convert to desired type
317-
atom_tensor.resize(static_cast<size_t>(natoms) * static_cast<size_t>(odim));
318-
torch::Tensor atom_tensor_converted;
319-
if (std::is_same<VALUETYPE, float>::value) {
320-
atom_tensor_converted = atom_tensor_tensor.to(torch::kFloat32);
321-
auto atom_tensor_acc = atom_tensor_converted.accessor<float, 3>();
322-
for (int i = 0; i < natoms; ++i) {
323-
for (int j = 0; j < odim; ++j) {
324-
atom_tensor[i * odim + j] = atom_tensor_acc[0][i][j];
325-
}
326-
}
327-
} else {
328-
atom_tensor_converted = atom_tensor_tensor.to(torch::kFloat64);
329-
auto atom_tensor_acc = atom_tensor_converted.accessor<double, 3>();
330-
for (int i = 0; i < natoms; ++i) {
331-
for (int j = 0; j < odim; ++j) {
332-
atom_tensor[i * odim + j] = atom_tensor_acc[0][i][j];
333-
}
334-
}
335-
}
336-
337-
if (request_deriv) {
338-
// Try to get derivative tensors with error handling
339-
torch::Tensor force_tensor, virial_tensor, atom_virial_tensor;
340-
341-
if (result_dict.contains("force")) {
342-
force_tensor = result_dict.at("force").toTensor().cpu();
343-
} else {
344-
throw deepmd::deepmd_exception(
345-
"PyTorch tensor model output missing force tensor when derivatives "
346-
"requested");
347-
}
348-
349-
if (result_dict.contains("virial")) {
350-
virial_tensor = result_dict.at("virial").toTensor().cpu();
351-
} else {
352-
throw deepmd::deepmd_exception(
353-
"PyTorch tensor model output missing virial tensor when derivatives "
354-
"requested");
355-
}
356-
357-
if (result_dict.contains("atomic_virial")) {
358-
atom_virial_tensor = result_dict.at("atomic_virial").toTensor().cpu();
359-
} else if (result_dict.contains("atom_virial")) {
360-
atom_virial_tensor = result_dict.at("atom_virial").toTensor().cpu();
361-
} else {
362-
// Fill with zeros when atomic virial is not available
363-
// This may happen with some models that don't compute atomic virial
364-
atom_virial_tensor =
365-
torch::zeros({1, odim, natoms, 9}, virial_tensor.options());
366-
}
367-
368-
// Copy force - convert to desired type
369-
force.resize(static_cast<size_t>(natoms) * 3 * static_cast<size_t>(odim));
370-
torch::Tensor force_converted;
371-
if (std::is_same<VALUETYPE, float>::value) {
372-
force_converted = force_tensor.to(torch::kFloat32);
373-
auto force_acc = force_converted.accessor<float, 4>();
374-
for (int d = 0; d < odim; ++d) {
375-
for (int i = 0; i < natoms; ++i) {
376-
for (int j = 0; j < 3; ++j) {
377-
force[d * natoms * 3 + i * 3 + j] = force_acc[0][d][i][j];
378-
}
379-
}
380-
}
381-
} else {
382-
force_converted = force_tensor.to(torch::kFloat64);
383-
auto force_acc = force_converted.accessor<double, 4>();
384-
for (int d = 0; d < odim; ++d) {
385-
for (int i = 0; i < natoms; ++i) {
386-
for (int j = 0; j < 3; ++j) {
387-
force[d * natoms * 3 + i * 3 + j] = force_acc[0][d][i][j];
388-
}
389-
}
390-
}
391-
}
392-
393-
// Copy virial - convert to desired type
394-
virial.resize(odim * 9);
395-
torch::Tensor virial_converted;
396-
if (std::is_same<VALUETYPE, float>::value) {
397-
virial_converted = virial_tensor.to(torch::kFloat32);
398-
auto virial_acc = virial_converted.accessor<float, 3>();
399-
for (int d = 0; d < odim; ++d) {
400-
for (int i = 0; i < 9; ++i) {
401-
virial[d * 9 + i] = virial_acc[0][d][i];
402-
}
403-
}
404-
} else {
405-
virial_converted = virial_tensor.to(torch::kFloat64);
406-
auto virial_acc = virial_converted.accessor<double, 3>();
407-
for (int d = 0; d < odim; ++d) {
408-
for (int i = 0; i < 9; ++i) {
409-
virial[d * 9 + i] = virial_acc[0][d][i];
410-
}
411-
}
412-
}
413-
414-
// Copy atom virial - convert to desired type
415-
atom_virial.resize(static_cast<size_t>(natoms) * 9 *
416-
static_cast<size_t>(odim));
417-
torch::Tensor atom_virial_converted;
418-
if (std::is_same<VALUETYPE, float>::value) {
419-
atom_virial_converted = atom_virial_tensor.to(torch::kFloat32);
420-
auto atom_virial_acc = atom_virial_converted.accessor<float, 4>();
421-
for (int d = 0; d < odim; ++d) {
422-
for (int i = 0; i < natoms; ++i) {
423-
for (int j = 0; j < 9; ++j) {
424-
atom_virial[d * natoms * 9 + i * 9 + j] =
425-
atom_virial_acc[0][d][i][j];
426-
}
427-
}
428-
}
429-
} else {
430-
atom_virial_converted = atom_virial_tensor.to(torch::kFloat64);
431-
auto atom_virial_acc = atom_virial_converted.accessor<double, 4>();
432-
for (int d = 0; d < odim; ++d) {
433-
for (int i = 0; i < natoms; ++i) {
434-
for (int j = 0; j < 9; ++j) {
435-
atom_virial[d * natoms * 9 + i * 9 + j] =
436-
atom_virial_acc[0][d][i][j];
437-
}
438-
}
439-
}
440-
}
441-
}
210+
// This is the simpler version without neighbor list optimization
211+
// Use a dummy neighbor list and call the full version
212+
deepmd::InputNlist dummy_nlist;
213+
// Initialize dummy neighbor list with empty data
214+
dummy_nlist.inum = atype.size();
215+
dummy_nlist.ilist.resize(dummy_nlist.inum);
216+
std::iota(dummy_nlist.ilist.begin(), dummy_nlist.ilist.end(), 0);
217+
dummy_nlist.numneigh.resize(dummy_nlist.inum, 0);
218+
dummy_nlist.firstneigh.resize(dummy_nlist.inum);
219+
220+
// Call the neighbor list version with nghost=0 and empty neighbor list
221+
compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord,
222+
atype, box, 0, dummy_nlist, request_deriv);
442223
}
443224

444225
template <typename VALUETYPE>

0 commit comments

Comments
 (0)