Skip to content

Commit c42ffe5

Browse files
committed
refactor(pt): Optimize type embedding handling in se_atten
- Replaced direct assignment of `type_embd_data` with a call to `register_buffer` for better memory management. - Improved clarity by using a temporary variable `embd_tensor` for storing the output of the embedding network before registration. - Maintained functionality for both one-side and two-side type embeddings, ensuring consistent behavior across modes. These changes enhance the maintainability and performance of the descriptor model.
1 parent 4f72994 commit c42ffe5

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
479479
if self.type_one_side:
480480
# One-side: only neighbor types, much simpler!
481481
# Precompute for all (ntypes+1) neighbor types
482-
self.type_embd_data = self.filter_layers_strip.networks[0](
483-
full_embd
484-
).detach()
482+
embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach()
483+
self.register_buffer("type_embd_data", embd_tensor)
485484
else:
486485
# Two-side: all (ntypes+1)² type pair combinations
487486
# Create [neighbor, center] combinations
@@ -494,9 +493,10 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
494493
)
495494
# Precompute for all type pairs
496495
# Index formula: idx = center_type * nt + neighbor_type
497-
self.type_embd_data = self.filter_layers_strip.networks[0](
496+
embd_tensor = self.filter_layers_strip.networks[0](
498497
two_side_embd
499498
).detach()
499+
self.register_buffer("type_embd_data", embd_tensor)
500500

501501
def forward(
502502
self,

0 commit comments

Comments
 (0)