diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index 8a2bd4f75c..ab01a90774 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -286,7 +286,7 @@ def forward(self, atype): type_embedding: """ - return self.embedding(atype.device)[atype] + return torch.embedding(self.embedding(atype.device), atype) def get_full_embedding(self, device: torch.device): """