Skip to content

Commit 83165db

Browse files
committed
fix(pt): Remove existing type_embd_data before registering new buffer in DescrptBlockSeAtten
1 parent bffe486 commit 83165db

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
480480
# One-side: only neighbor types, much simpler!
481481
# Precompute for all (ntypes+1) neighbor types
482482
embd_tensor = self.filter_layers_strip.networks[0](full_embd).detach()
483+
if hasattr(self, "type_embd_data"):
484+
del self.type_embd_data
483485
self.register_buffer("type_embd_data", embd_tensor)
484486
else:
485487
# Two-side: all (ntypes+1)² type pair combinations
@@ -496,6 +498,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
496498
embd_tensor = self.filter_layers_strip.networks[0](
497499
two_side_embd
498500
).detach()
501+
if hasattr(self, "type_embd_data"):
502+
del self.type_embd_data
499503
self.register_buffer("type_embd_data", embd_tensor)
500504

501505
def forward(

0 commit comments

Comments
 (0)