Skip to content

Commit cfcbb7b

Browse files
committed
add UT and enable compress in forward
1 parent 7b1690a commit cfcbb7b

2 files changed

Lines changed: 405 additions & 2 deletions

File tree

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,8 +950,22 @@ def forward(
950950
# nfnl x nt_i x nt_j x ng
951951
gg = self.filter_layers.networks[0](ss)
952952
elif self.tebd_input_mode in ["strip"]:
953-
# nfnl x nt_i x nt_j x ng
954-
gg_s = self.filter_layers.networks[0](ss)
953+
if self.compress:
954+
# Use tabulated computation for the geometric embedding
955+
ebd_env_ij = env_ij.view(-1, 1)
956+
gg_s = torch.ops.deepmd.tabulate_fusion_se_t(
957+
self.compress_data[0].contiguous(),
958+
self.compress_info[0].cpu().contiguous(),
959+
ebd_env_ij.contiguous(),
960+
env_ij.contiguous(),
961+
self.filter_neuron[-1],
962+
)[0]
963+
# Reshape back to the expected format: nfnl x nt_i x nt_j x ng
964+
gg_s = gg_s.view(nfnl, nnei, nnei, self.filter_neuron[-1])
965+
else:
966+
# nfnl x nt_i x nt_j x ng
967+
gg_s = self.filter_layers.networks[0](ss)
968+
955969
assert self.filter_layers_strip is not None
956970
assert type_embedding is not None
957971
ng = self.filter_neuron[-1]

0 commit comments

Comments
 (0)