|
5 | 5 | #include <torch/csrc/jit/runtime/jit_exception.h> |
6 | 6 |
|
7 | 7 | #include <cstdint> |
| 8 | +#include <numeric> // for std::iota |
8 | 9 | #include <sstream> |
9 | 10 |
|
10 | 11 | #include "common.h" |
@@ -206,239 +207,19 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor, |
206 | 207 | const std::vector<int>& atype, |
207 | 208 | const std::vector<VALUETYPE>& box, |
208 | 209 | const bool request_deriv) { |
209 | | - torch::Device device(torch::kCUDA, gpu_id); |
210 | | - if (!gpu_enabled) { |
211 | | - device = torch::Device(torch::kCPU); |
212 | | - } |
213 | | - |
214 | | - int natoms = atype.size(); |
215 | | - auto options = torch::TensorOptions().dtype(torch::kFloat64); |
216 | | - if (std::is_same<VALUETYPE, float>::value) { |
217 | | - options = torch::TensorOptions().dtype(torch::kFloat32); |
218 | | - } |
219 | | - auto int_option = |
220 | | - torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64); |
221 | | - |
222 | | - // Convert inputs to tensors |
223 | | - std::vector<VALUETYPE> coord_wrapped = coord; |
224 | | - at::Tensor coord_tensor = |
225 | | - torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options) |
226 | | - .to(device); |
227 | | - |
228 | | - std::vector<std::int64_t> atype_64(atype.begin(), atype.end()); |
229 | | - at::Tensor atype_tensor = |
230 | | - torch::from_blob(atype_64.data(), {1, natoms}, int_option).to(device); |
231 | | - |
232 | | - std::vector<VALUETYPE> box_wrapped = box; |
233 | | - at::Tensor box_tensor = |
234 | | - torch::from_blob(box_wrapped.data(), {1, 9}, options).to(device); |
235 | | - |
236 | | - // Create input vector |
237 | | - std::vector<torch::jit::IValue> inputs; |
238 | | - inputs.push_back(coord_tensor); |
239 | | - inputs.push_back(atype_tensor); |
240 | | - inputs.push_back(box_tensor); |
241 | | - |
242 | | - // Forward pass through model |
243 | | - torch::jit::IValue result; |
244 | | - if (request_deriv) { |
245 | | - inputs.push_back(torch::tensor(true)); // do_atomic_virial |
246 | | - result = module.forward(inputs); |
247 | | - } else { |
248 | | - result = module.forward(inputs); |
249 | | - } |
250 | | - |
251 | | - auto result_dict = result.toGenericDict(); |
252 | | - |
253 | | - // Extract results - try common key names |
254 | | - torch::Tensor global_tensor_tensor, atom_tensor_tensor; |
255 | | - |
256 | | - // Try different possible keys for global tensor |
257 | | - if (result_dict.contains("global_tensor")) { |
258 | | - global_tensor_tensor = result_dict.at("global_tensor").toTensor().cpu(); |
259 | | - } else if (result_dict.contains("tensor")) { |
260 | | - global_tensor_tensor = result_dict.at("tensor").toTensor().cpu(); |
261 | | - } else if (result_dict.contains("global_dipole")) { |
262 | | - global_tensor_tensor = result_dict.at("global_dipole").toTensor().cpu(); |
263 | | - } else if (result_dict.contains("dipole")) { |
264 | | - // For models that only output atomic tensor, sum to get global |
265 | | - auto dipole_tensor = result_dict.at("dipole").toTensor().cpu(); |
266 | | - global_tensor_tensor = |
267 | | - torch::sum(dipole_tensor, 1, true); // Sum over atoms, keep dims |
268 | | - } else { |
269 | | - throw deepmd::deepmd_exception( |
270 | | - "PyTorch tensor model output missing global tensor (expected " |
271 | | - "'global_tensor', 'tensor', 'global_dipole', or 'dipole' key)"); |
272 | | - } |
273 | | - |
274 | | - // Try different possible keys for atomic tensor |
275 | | - if (result_dict.contains("atomic_tensor")) { |
276 | | - atom_tensor_tensor = result_dict.at("atomic_tensor").toTensor().cpu(); |
277 | | - } else if (result_dict.contains("atom_tensor")) { |
278 | | - atom_tensor_tensor = result_dict.at("atom_tensor").toTensor().cpu(); |
279 | | - } else if (result_dict.contains("dipole")) { |
280 | | - atom_tensor_tensor = result_dict.at("dipole").toTensor().cpu(); |
281 | | - } else { |
282 | | - throw deepmd::deepmd_exception( |
283 | | - "PyTorch tensor model output missing atomic tensor (expected " |
284 | | - "'atomic_tensor', 'atom_tensor', or 'dipole' key)"); |
285 | | - } |
286 | | - |
287 | | - // Determine task dimension if not already known |
288 | | - if (odim == -1) { |
289 | | - if (global_tensor_tensor.dim() >= 2) { |
290 | | - odim = global_tensor_tensor.size(-1); |
291 | | - } else if (atom_tensor_tensor.dim() >= 3) { |
292 | | - odim = atom_tensor_tensor.size(-1); |
293 | | - } else { |
294 | | - throw deepmd::deepmd_exception( |
295 | | - "Unable to determine task dimension from model output"); |
296 | | - } |
297 | | - } |
298 | | - |
299 | | - // Copy global tensor - convert to desired type |
300 | | - global_tensor.resize(odim); |
301 | | - torch::Tensor global_tensor_converted; |
302 | | - if (std::is_same<VALUETYPE, float>::value) { |
303 | | - global_tensor_converted = global_tensor_tensor.to(torch::kFloat32); |
304 | | - auto global_tensor_acc = global_tensor_converted.accessor<float, 2>(); |
305 | | - for (int i = 0; i < odim; ++i) { |
306 | | - global_tensor[i] = global_tensor_acc[0][i]; |
307 | | - } |
308 | | - } else { |
309 | | - global_tensor_converted = global_tensor_tensor.to(torch::kFloat64); |
310 | | - auto global_tensor_acc = global_tensor_converted.accessor<double, 2>(); |
311 | | - for (int i = 0; i < odim; ++i) { |
312 | | - global_tensor[i] = global_tensor_acc[0][i]; |
313 | | - } |
314 | | - } |
315 | | - |
316 | | - // Copy atom tensor - convert to desired type |
317 | | - atom_tensor.resize(static_cast<size_t>(natoms) * static_cast<size_t>(odim)); |
318 | | - torch::Tensor atom_tensor_converted; |
319 | | - if (std::is_same<VALUETYPE, float>::value) { |
320 | | - atom_tensor_converted = atom_tensor_tensor.to(torch::kFloat32); |
321 | | - auto atom_tensor_acc = atom_tensor_converted.accessor<float, 3>(); |
322 | | - for (int i = 0; i < natoms; ++i) { |
323 | | - for (int j = 0; j < odim; ++j) { |
324 | | - atom_tensor[i * odim + j] = atom_tensor_acc[0][i][j]; |
325 | | - } |
326 | | - } |
327 | | - } else { |
328 | | - atom_tensor_converted = atom_tensor_tensor.to(torch::kFloat64); |
329 | | - auto atom_tensor_acc = atom_tensor_converted.accessor<double, 3>(); |
330 | | - for (int i = 0; i < natoms; ++i) { |
331 | | - for (int j = 0; j < odim; ++j) { |
332 | | - atom_tensor[i * odim + j] = atom_tensor_acc[0][i][j]; |
333 | | - } |
334 | | - } |
335 | | - } |
336 | | - |
337 | | - if (request_deriv) { |
338 | | - // Try to get derivative tensors with error handling |
339 | | - torch::Tensor force_tensor, virial_tensor, atom_virial_tensor; |
340 | | - |
341 | | - if (result_dict.contains("force")) { |
342 | | - force_tensor = result_dict.at("force").toTensor().cpu(); |
343 | | - } else { |
344 | | - throw deepmd::deepmd_exception( |
345 | | - "PyTorch tensor model output missing force tensor when derivatives " |
346 | | - "requested"); |
347 | | - } |
348 | | - |
349 | | - if (result_dict.contains("virial")) { |
350 | | - virial_tensor = result_dict.at("virial").toTensor().cpu(); |
351 | | - } else { |
352 | | - throw deepmd::deepmd_exception( |
353 | | - "PyTorch tensor model output missing virial tensor when derivatives " |
354 | | - "requested"); |
355 | | - } |
356 | | - |
357 | | - if (result_dict.contains("atomic_virial")) { |
358 | | - atom_virial_tensor = result_dict.at("atomic_virial").toTensor().cpu(); |
359 | | - } else if (result_dict.contains("atom_virial")) { |
360 | | - atom_virial_tensor = result_dict.at("atom_virial").toTensor().cpu(); |
361 | | - } else { |
362 | | - // Fill with zeros when atomic virial is not available |
363 | | - // This may happen with some models that don't compute atomic virial |
364 | | - atom_virial_tensor = |
365 | | - torch::zeros({1, odim, natoms, 9}, virial_tensor.options()); |
366 | | - } |
367 | | - |
368 | | - // Copy force - convert to desired type |
369 | | - force.resize(static_cast<size_t>(natoms) * 3 * static_cast<size_t>(odim)); |
370 | | - torch::Tensor force_converted; |
371 | | - if (std::is_same<VALUETYPE, float>::value) { |
372 | | - force_converted = force_tensor.to(torch::kFloat32); |
373 | | - auto force_acc = force_converted.accessor<float, 4>(); |
374 | | - for (int d = 0; d < odim; ++d) { |
375 | | - for (int i = 0; i < natoms; ++i) { |
376 | | - for (int j = 0; j < 3; ++j) { |
377 | | - force[d * natoms * 3 + i * 3 + j] = force_acc[0][d][i][j]; |
378 | | - } |
379 | | - } |
380 | | - } |
381 | | - } else { |
382 | | - force_converted = force_tensor.to(torch::kFloat64); |
383 | | - auto force_acc = force_converted.accessor<double, 4>(); |
384 | | - for (int d = 0; d < odim; ++d) { |
385 | | - for (int i = 0; i < natoms; ++i) { |
386 | | - for (int j = 0; j < 3; ++j) { |
387 | | - force[d * natoms * 3 + i * 3 + j] = force_acc[0][d][i][j]; |
388 | | - } |
389 | | - } |
390 | | - } |
391 | | - } |
392 | | - |
393 | | - // Copy virial - convert to desired type |
394 | | - virial.resize(odim * 9); |
395 | | - torch::Tensor virial_converted; |
396 | | - if (std::is_same<VALUETYPE, float>::value) { |
397 | | - virial_converted = virial_tensor.to(torch::kFloat32); |
398 | | - auto virial_acc = virial_converted.accessor<float, 3>(); |
399 | | - for (int d = 0; d < odim; ++d) { |
400 | | - for (int i = 0; i < 9; ++i) { |
401 | | - virial[d * 9 + i] = virial_acc[0][d][i]; |
402 | | - } |
403 | | - } |
404 | | - } else { |
405 | | - virial_converted = virial_tensor.to(torch::kFloat64); |
406 | | - auto virial_acc = virial_converted.accessor<double, 3>(); |
407 | | - for (int d = 0; d < odim; ++d) { |
408 | | - for (int i = 0; i < 9; ++i) { |
409 | | - virial[d * 9 + i] = virial_acc[0][d][i]; |
410 | | - } |
411 | | - } |
412 | | - } |
413 | | - |
414 | | - // Copy atom virial - convert to desired type |
415 | | - atom_virial.resize(static_cast<size_t>(natoms) * 9 * |
416 | | - static_cast<size_t>(odim)); |
417 | | - torch::Tensor atom_virial_converted; |
418 | | - if (std::is_same<VALUETYPE, float>::value) { |
419 | | - atom_virial_converted = atom_virial_tensor.to(torch::kFloat32); |
420 | | - auto atom_virial_acc = atom_virial_converted.accessor<float, 4>(); |
421 | | - for (int d = 0; d < odim; ++d) { |
422 | | - for (int i = 0; i < natoms; ++i) { |
423 | | - for (int j = 0; j < 9; ++j) { |
424 | | - atom_virial[d * natoms * 9 + i * 9 + j] = |
425 | | - atom_virial_acc[0][d][i][j]; |
426 | | - } |
427 | | - } |
428 | | - } |
429 | | - } else { |
430 | | - atom_virial_converted = atom_virial_tensor.to(torch::kFloat64); |
431 | | - auto atom_virial_acc = atom_virial_converted.accessor<double, 4>(); |
432 | | - for (int d = 0; d < odim; ++d) { |
433 | | - for (int i = 0; i < natoms; ++i) { |
434 | | - for (int j = 0; j < 9; ++j) { |
435 | | - atom_virial[d * natoms * 9 + i * 9 + j] = |
436 | | - atom_virial_acc[0][d][i][j]; |
437 | | - } |
438 | | - } |
439 | | - } |
440 | | - } |
441 | | - } |
| 210 | + // This is the simpler version without neighbor list optimization |
| 211 | + // Use a dummy neighbor list and call the full version |
| 212 | + deepmd::InputNlist dummy_nlist; |
| 213 | + // Initialize dummy neighbor list with empty data |
| 214 | + dummy_nlist.inum = atype.size(); |
| 215 | + dummy_nlist.ilist.resize(dummy_nlist.inum); |
| 216 | + std::iota(dummy_nlist.ilist.begin(), dummy_nlist.ilist.end(), 0); |
| 217 | + dummy_nlist.numneigh.resize(dummy_nlist.inum, 0); |
| 218 | + dummy_nlist.firstneigh.resize(dummy_nlist.inum); |
| 219 | + |
| 220 | + // Call the neighbor list version with nghost=0 and empty neighbor list |
| 221 | + compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord, |
| 222 | + atype, box, 0, dummy_nlist, request_deriv); |
442 | 223 | } |
443 | 224 |
|
444 | 225 | template <typename VALUETYPE> |
|
0 commit comments