|
13 | 13 |
|
14 | 14 | using namespace deepmd; |
15 | 15 |
|
| 16 | +torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) { |
| 17 | + size_t total_size = 0; |
| 18 | + for (const auto& row : data) { |
| 19 | + total_size += row.size(); |
| 20 | + } |
| 21 | + std::vector<int> flat_data; |
| 22 | + flat_data.reserve(total_size); |
| 23 | + for (const auto& row : data) { |
| 24 | + flat_data.insert(flat_data.end(), row.begin(), row.end()); |
| 25 | + } |
| 26 | + |
| 27 | + torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); |
| 28 | + int nloc = data.size(); |
| 29 | + int nnei = nloc > 0 ? total_size / nloc : 0; |
| 30 | + return flat_tensor.view({1, nloc, nnei}); |
| 31 | +} |
| 32 | + |
16 | 33 | void DeepTensorPT::translate_error(std::function<void()> f) { |
17 | 34 | try { |
18 | 35 | f(); |
@@ -434,13 +451,157 @@ void DeepTensorPT::compute_inner(std::vector<VALUETYPE>& global_tensor, |
434 | 451 | const std::vector<int>& atype, |
435 | 452 | const std::vector<VALUETYPE>& box, |
436 | 453 | const int nghost, |
437 | | - const InputNlist& inlist, |
| 454 | + const InputNlist& lmp_list, |
438 | 455 | const bool request_deriv) { |
439 | | - // Implement neighbor list support following DeepPotPT pattern |
440 | | - // For now, use the simple compute_inner approach |
441 | | - // TODO: Add full neighbor list optimization for better performance |
442 | | - compute_inner(global_tensor, force, virial, atom_tensor, atom_virial, coord, |
443 | | - atype, box, request_deriv); |
| 456 | + torch::Device device(torch::kCUDA, gpu_id); |
| 457 | + if (!gpu_enabled) { |
| 458 | + device = torch::Device(torch::kCPU); |
| 459 | + } |
| 460 | + |
| 461 | + int natoms = atype.size(); |
| 462 | + auto options = torch::TensorOptions().dtype(torch::kFloat64); |
| 463 | + torch::ScalarType floatType = torch::kFloat64; |
| 464 | + if (std::is_same<VALUETYPE, float>::value) { |
| 465 | + options = torch::TensorOptions().dtype(torch::kFloat32); |
| 466 | + floatType = torch::kFloat32; |
| 467 | + } |
| 468 | + auto int32_option = |
| 469 | + torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32); |
| 470 | + auto int_option = |
| 471 | + torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64); |
| 472 | + |
| 473 | + // Select real atoms following DeepPotPT pattern |
| 474 | + std::vector<VALUETYPE> dcoord, aparam_; |
| 475 | + std::vector<int> datype, fwd_map, bkw_map; |
| 476 | + int nghost_real, nall_real, nloc_real; |
| 477 | + int nall = natoms; |
| 478 | + int nframes = 1; |
| 479 | + std::vector<VALUETYPE> aparam; // Empty for tensor models |
| 480 | + select_real_atoms_coord(dcoord, datype, aparam_, nghost_real, fwd_map, |
| 481 | + bkw_map, nall_real, nloc_real, coord, atype, aparam, |
| 482 | + nghost, ntypes, nframes, 0, nall, false); |
| 483 | + int nloc = nall_real - nghost_real; |
| 484 | + |
| 485 | + std::vector<VALUETYPE> coord_wrapped = dcoord; |
| 486 | + at::Tensor coord_wrapped_Tensor = |
| 487 | + torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options) |
| 488 | + .to(device); |
| 489 | + std::vector<std::int64_t> atype_64(datype.begin(), datype.end()); |
| 490 | + at::Tensor atype_Tensor = |
| 491 | + torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device); |
| 492 | + |
| 493 | + // Process neighbor list following DeepPotPT pattern |
| 494 | + nlist_data.copy_from_nlist(lmp_list, nall - nghost); |
| 495 | + nlist_data.shuffle_exclude_empty(fwd_map); |
| 496 | + nlist_data.padding(); |
| 497 | + |
| 498 | + at::Tensor firstneigh = createNlistTensor(nlist_data.jlist); |
| 499 | + firstneigh_tensor = firstneigh.to(torch::kInt64).to(device); |
| 500 | + |
| 501 | + // Prepare box tensor |
| 502 | + std::vector<VALUETYPE> box_wrapped = box; |
| 503 | + at::Tensor box_tensor = |
| 504 | + torch::from_blob(box_wrapped.data(), {1, 9}, options).to(device); |
| 505 | + |
| 506 | + // Create input vector for model |
| 507 | + std::vector<torch::jit::IValue> inputs; |
| 508 | + inputs.push_back(coord_wrapped_Tensor); |
| 509 | + inputs.push_back(atype_Tensor); |
| 510 | + inputs.push_back(firstneigh_tensor); |
| 511 | + inputs.push_back(box_tensor); |
| 512 | + |
| 513 | + bool do_atom_virial_tensor = request_deriv; |
| 514 | + inputs.push_back(do_atom_virial_tensor); |
| 515 | + |
| 516 | + // Forward pass through model |
| 517 | + c10::Dict<c10::IValue, c10::IValue> outputs = |
| 518 | + module.forward(inputs).toGenericDict(); |
| 519 | + |
| 520 | + // Process global tensor |
| 521 | + if (outputs.contains("global_tensor") || outputs.contains("dipole") || |
| 522 | + outputs.contains("global_dipole")) { |
| 523 | + c10::IValue tensor_out; |
| 524 | + if (outputs.contains("global_tensor")) { |
| 525 | + tensor_out = outputs.at("global_tensor"); |
| 526 | + } else if (outputs.contains("global_dipole")) { |
| 527 | + tensor_out = outputs.at("global_dipole"); |
| 528 | + } else { |
| 529 | + tensor_out = outputs.at("dipole"); |
| 530 | + } |
| 531 | + |
| 532 | + torch::Tensor flat_tensor = tensor_out.toTensor().view({-1}).to(floatType); |
| 533 | + torch::Tensor cpu_tensor = flat_tensor.to(torch::kCPU); |
| 534 | + global_tensor.assign(cpu_tensor.data_ptr<VALUETYPE>(), |
| 535 | + cpu_tensor.data_ptr<VALUETYPE>() + cpu_tensor.numel()); |
| 536 | + } |
| 537 | + |
| 538 | + // Process force if available |
| 539 | + if (outputs.contains("force") || outputs.contains("extended_force")) { |
| 540 | + c10::IValue force_out = outputs.contains("extended_force") |
| 541 | + ? outputs.at("extended_force") |
| 542 | + : outputs.at("force"); |
| 543 | + torch::Tensor flat_force = force_out.toTensor().view({-1}).to(floatType); |
| 544 | + torch::Tensor cpu_force = flat_force.to(torch::kCPU); |
| 545 | + std::vector<VALUETYPE> dforce; |
| 546 | + dforce.assign(cpu_force.data_ptr<VALUETYPE>(), |
| 547 | + cpu_force.data_ptr<VALUETYPE>() + cpu_force.numel()); |
| 548 | + |
| 549 | + // Map back to original atom order using select_map |
| 550 | + force.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim * 3); |
| 551 | + select_map<VALUETYPE>(force, dforce, bkw_map, odim * 3, nframes, |
| 552 | + fwd_map.size(), nall_real); |
| 553 | + } |
| 554 | + |
| 555 | + // Process virial if available |
| 556 | + if (outputs.contains("virial")) { |
| 557 | + c10::IValue virial_out = outputs.at("virial"); |
| 558 | + torch::Tensor flat_virial = virial_out.toTensor().view({-1}).to(floatType); |
| 559 | + torch::Tensor cpu_virial = flat_virial.to(torch::kCPU); |
| 560 | + virial.assign(cpu_virial.data_ptr<VALUETYPE>(), |
| 561 | + cpu_virial.data_ptr<VALUETYPE>() + cpu_virial.numel()); |
| 562 | + } |
| 563 | + |
| 564 | + // Process atom tensor if available |
| 565 | + if (outputs.contains("atom_tensor")) { |
| 566 | + c10::IValue atom_tensor_out = outputs.at("atom_tensor"); |
| 567 | + torch::Tensor flat_atom_tensor = |
| 568 | + atom_tensor_out.toTensor().view({-1}).to(floatType); |
| 569 | + torch::Tensor cpu_atom_tensor = flat_atom_tensor.to(torch::kCPU); |
| 570 | + std::vector<VALUETYPE> datom_tensor_tmp; |
| 571 | + datom_tensor_tmp.assign( |
| 572 | + cpu_atom_tensor.data_ptr<VALUETYPE>(), |
| 573 | + cpu_atom_tensor.data_ptr<VALUETYPE>() + cpu_atom_tensor.numel()); |
| 574 | + |
| 575 | + // Map back to original atom order using select_map |
| 576 | + atom_tensor.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim); |
| 577 | + select_map<VALUETYPE>(atom_tensor, datom_tensor_tmp, bkw_map, odim, nframes, |
| 578 | + fwd_map.size(), nall_real); |
| 579 | + } |
| 580 | + |
| 581 | + // Process atomic virial if requested and available |
| 582 | + if (request_deriv && (outputs.contains("atom_virial") || |
| 583 | + outputs.contains("extended_virial"))) { |
| 584 | + c10::IValue atom_virial_out = outputs.contains("extended_virial") |
| 585 | + ? outputs.at("extended_virial") |
| 586 | + : outputs.at("atom_virial"); |
| 587 | + torch::Tensor flat_atom_virial = |
| 588 | + atom_virial_out.toTensor().view({-1}).to(floatType); |
| 589 | + torch::Tensor cpu_atom_virial = flat_atom_virial.to(torch::kCPU); |
| 590 | + std::vector<VALUETYPE> datom_virial_tmp; |
| 591 | + datom_virial_tmp.assign( |
| 592 | + cpu_atom_virial.data_ptr<VALUETYPE>(), |
| 593 | + cpu_atom_virial.data_ptr<VALUETYPE>() + cpu_atom_virial.numel()); |
| 594 | + |
| 595 | + // Map back to original atom order using select_map |
| 596 | + atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * odim * |
| 597 | + 9); |
| 598 | + select_map<VALUETYPE>(atom_virial, datom_virial_tmp, bkw_map, odim * 9, |
| 599 | + nframes, fwd_map.size(), nall_real); |
| 600 | + } else if (request_deriv) { |
| 601 | + // Fill with zeros if atomic virial not available but requested |
| 602 | + atom_virial.assign(static_cast<size_t>(natoms) * odim * 9, |
| 603 | + static_cast<VALUETYPE>(0.0)); |
| 604 | + } |
444 | 605 | } |
445 | 606 |
|
446 | 607 | // Public wrapper functions |
|
0 commit comments