Skip to content

Commit d243804

Browse files
committed
create custom OP for se_t_tebd
1 parent 3aff58b commit d243804

3 files changed

Lines changed: 65 additions & 15 deletions

File tree

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ def forward(
904904
self.rcut_smth,
905905
protection=self.env_protection,
906906
)
907+
# dmatrix: [1/r, dx/r^2, dy/r^2, dz/r^2], sw: distance weighting
907908
# nb x nloc x nnei
908909
exclude_mask = self.emask(nlist, extended_atype)
909910
nlist = torch.where(exclude_mask != 0, nlist, -1)
@@ -924,11 +925,11 @@ def forward(
924925
rr = dmatrix
925926
rr = rr * exclude_mask[:, :, None]
926927

927-
# nfnl x nt_i x 3
928+
# nfnl x nt_i x 3: direction vectors
928929
rr_i = rr[:, :, 1:]
929930
# nfnl x nt_j x 3
930931
rr_j = rr[:, :, 1:]
931-
# nfnl x nt_i x nt_j
932+
# nfnl x nt_i x nt_j: three-body angular correlations (cos theta_ij)
932933
env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
933934
# nfnl x nt_i x nt_j x 1
934935
ss = env_ij.unsqueeze(-1)
@@ -951,18 +952,19 @@ def forward(
951952
gg = self.filter_layers.networks[0](ss)
952953
elif self.tebd_input_mode in ["strip"]:
953954
if self.compress:
954-
# Use tabulated computation for the geometric embedding
955+
# Tabulated geometric embedding from angular features
956+
# using SE_T_TEBD specific function
955957
ebd_env_ij = env_ij.view(-1, 1)
956-
gg_s_compressed = torch.ops.deepmd.tabulate_fusion_se_t(
958+
gg_s = torch.ops.deepmd.tabulate_fusion_se_t_tebd(
957959
self.compress_data[0].contiguous(),
958960
self.compress_info[0].cpu().contiguous(),
959961
ebd_env_ij.contiguous(),
960962
env_ij.contiguous(),
961963
self.filter_neuron[-1],
962964
)[0]
963-
# The compressed output is nfnl x ng, need to expand to nfnl x nt_i x nt_j x ng
964-
# by replicating across the neighbor dimensions
965-
gg_s = gg_s_compressed.view(nfnl, 1, 1, self.filter_neuron[-1]).expand(nfnl, nnei, nnei, self.filter_neuron[-1])
965+
# SE_T_TEBD tabulation preserves the full neighbor structure
966+
# nfnl x nt_i x nt_j x ng
967+
gg_s = gg_s.view(nfnl, nnei, nnei, self.filter_neuron[-1])
966968
else:
967969
# nfnl x nt_i x nt_j x ng
968970
gg_s = self.filter_layers.networks[0](ss)
@@ -1010,16 +1012,19 @@ def forward(
10101012
# (nfnl x nt_i x nt_j) x ng
10111013
gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
10121014
if self.smooth:
1015+
# Apply distance weighting to type features
10131016
gg_t = (
10141017
gg_t
10151018
* sw.reshape(nfnl, self.nnei, 1, 1)
10161019
* sw.reshape(nfnl, 1, self.nnei, 1)
10171020
)
1021+
# Combine geometric and type embeddings: gg_s * (1 + gg_t)
10181022
# nfnl x nt_i x nt_j x ng
10191023
gg = gg_s * gg_t + gg_s
10201024
else:
10211025
raise NotImplementedError
10221026

1027+
# Contract angular correlations with learned features
10231028
# nfnl x ng
10241029
res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg)
10251030
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))

source/lib/src/gpu/tabulate.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
630630
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
631631
}
632632

633-
template <typename FPTYPE, int MM, int KK>
633+
template <typename FPTYPE, int MTILE, int KTILE>
634634
__global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
635635
FPTYPE* out,
636636
const FPTYPE* table,
@@ -668,7 +668,7 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial(
668668
}
669669
}
670670

671-
template <typename FPTYPE, int MM, int KK>
671+
template <typename FPTYPE, int MTILE, int KTILE>
672672
__global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial(
673673
FPTYPE* dy_dem_x,
674674
const FPTYPE* table,

source/op/pt/tabulate_multi_device.cc

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,7 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
389389
const torch::Tensor& em_tensor,
390390
const torch::Tensor& dy_tensor,
391391
const torch::Tensor& descriptor_tensor,
392-
torch::Tensor& dy_dem_x_tensor,
393-
torch::Tensor& dy_dem_tensor) {
392+
torch::Tensor& dy_dem_x_tensor) {
394393
// check input shape
395394
if (dy_tensor.dim() != 4) {
396395
throw std::invalid_argument("Dim of dy_tensor should be 4");
@@ -399,7 +398,6 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
399398
GetTensorDevice(table_tensor, device);
400399
// flat the tensors
401400
FPTYPE* dy_dem_x = dy_dem_x_tensor.view({-1}).data_ptr<FPTYPE>();
402-
FPTYPE* dy_dem = dy_dem_tensor.view({-1}).data_ptr<FPTYPE>();
403401

404402
const FPTYPE* table = table_tensor.view({-1}).data_ptr<FPTYPE>();
405403
const FPTYPE* table_info = table_info_tensor.view({-1}).data_ptr<FPTYPE>();
@@ -430,6 +428,54 @@ void TabulateFusionSeTTebdGradForward(const torch::Tensor& table_tensor,
430428
}
431429
}
432430

431+
template <typename FPTYPE>
432+
void TabulateFusionSeTTebdGradGradForward(const torch::Tensor& table_tensor,
433+
const torch::Tensor& table_info_tensor,
434+
const torch::Tensor& em_x_tensor,
435+
const torch::Tensor& em_tensor,
436+
const torch::Tensor& dz_dy_dem_x_tensor,
437+
const torch::Tensor& descriptor_tensor,
438+
torch::Tensor& dz_dy_tensor) {
439+
// Check input shape
440+
if (dz_dy_dem_x_tensor.dim() != 3) {
441+
throw std::invalid_argument("Dim of dz_dy_dem_x should be 3");
442+
}
443+
// get the device
444+
std::string device;
445+
GetTensorDevice(table_tensor, device);
446+
// flat the tensors
447+
FPTYPE* dz_dy = dz_dy_tensor.view({-1}).data_ptr<FPTYPE>();
448+
449+
const FPTYPE* table = table_tensor.view({-1}).data_ptr<FPTYPE>();
450+
const FPTYPE* table_info = table_info_tensor.view({-1}).data_ptr<FPTYPE>();
451+
const FPTYPE* em_x = em_x_tensor.view({-1}).data_ptr<FPTYPE>();
452+
const FPTYPE* em = em_tensor.view({-1}).data_ptr<FPTYPE>();
453+
const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.view({-1}).data_ptr<FPTYPE>();
454+
const int64_t nloc = em_x_tensor.size(0);
455+
const int64_t nnei_i = em_x_tensor.size(1);
456+
const int64_t nnei_j = em_x_tensor.size(2);
457+
const int64_t last_layer_size = descriptor_tensor.size(3);
458+
// compute
459+
if (device == "GPU") {
460+
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
461+
deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu(dz_dy, table, table_info, em_x,
462+
em, dz_dy_dem_x, nloc,
463+
nnei_i, nnei_j, last_layer_size);
464+
#else
465+
throw std::runtime_error(
466+
"The input tensor is on the GPU, but the GPU support for the "
467+
"customized OP library is not enabled.");
468+
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
469+
TORCH_CHECK(last_layer_size <= 1024,
470+
"In the process of model compression, the size of the "
471+
"last layer of embedding net must be less than 1024!");
472+
} else if (device == "CPU") {
473+
deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu(dz_dy, table, table_info, em_x,
474+
em, dz_dy_dem_x, nloc,
475+
nnei_i, nnei_j, last_layer_size);
476+
}
477+
}
478+
433479
template <typename FPTYPE>
434480
void TabulateFusionSeRForward(const torch::Tensor& table_tensor,
435481
const torch::Tensor& table_info_tensor,
@@ -1107,13 +1153,12 @@ class TabulateFusionSeTTebdOp
11071153
torch::Tensor dy_tensor = grad_output[0].contiguous();
11081154
// allocate output tensors
11091155
torch::Tensor dy_dem_x_tensor = torch::zeros_like(em_x_tensor);
1110-
torch::Tensor dy_dem_tensor = torch::zeros_like(em_tensor);
11111156
// compute
11121157
TabulateFusionSeTTebdGradForward<FPTYPE>(
11131158
table_tensor, table_info_tensor, em_x_tensor, em_tensor, dy_tensor,
1114-
descriptor_tensor, dy_dem_x_tensor, dy_dem_tensor);
1159+
descriptor_tensor, dy_dem_x_tensor);
11151160

1116-
return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, dy_dem_tensor,
1161+
return {at::Tensor(), at::Tensor(), dy_dem_x_tensor, at::Tensor(),
11171162
at::Tensor()};
11181163
}
11191164
};

0 commit comments

Comments
 (0)