Skip to content

Commit e9f5058

Browse files
committed
perf: use torch.embedding for type embedding
1 parent 43e0288 commit e9f5058

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

deepmd/pt/model/network/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def forward(self, atype):
286286
type_embedding:
287287
288288
"""
289-
return self.embedding(atype.device)[atype]
289+
return torch.embedding(self.embedding(atype.device), atype)
290290

291291
def get_full_embedding(self, device: torch.device):
292292
"""

0 commit comments

Comments
 (0)