Commit 8176173

perf: use torch.embedding for type embedding (deepmodeling#4747)
The backward step of the indexing operation is costly on GPU. Using the dedicated `torch.embedding` mitigates this problem.

Profiling results:

- Before: 32 ms (https://github.com/user-attachments/assets/6e2a4de1-433a-4b6a-8b59-a8458d66897c)
- After: 0.5 ms (https://github.com/user-attachments/assets/199ac925-8382-4a43-a6bb-584bb60159b2)

## Summary by CodeRabbit

- **Refactor**
  - Updated the embedding lookup mechanism in the model, potentially improving how embeddings are retrieved internally. No changes to the user interface or method signatures.
1 parent 30b762e commit 8176173

1 file changed: `deepmd/pt/model/network/network.py` (1 addition & 1 deletion)
```diff
@@ -286,7 +286,7 @@ def forward(self, atype):
         type_embedding:

         """
-        return self.embedding(atype.device)[atype]
+        return torch.embedding(self.embedding(atype.device), atype)

     def get_full_embedding(self, device: torch.device):
         """
```
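The one-line change can be sketched in isolation as follows. This is a minimal illustration, not the module's actual code: `weight` stands in for the table returned by `self.embedding(atype.device)`, and the shapes are arbitrary. `torch.embedding` produces the same forward result as plain advanced indexing, but routes the backward pass through a dedicated embedding kernel instead of the generic (and on GPU, slower) index-backward.

```python
import torch

torch.manual_seed(0)

# Illustrative stand-ins: an (ntypes x dim) type-embedding table and
# per-atom type indices of shape (nframes, nloc).
weight = torch.randn(4, 8, requires_grad=True)
atype = torch.tensor([[0, 1, 3], [2, 2, 0]])

# Before the commit: plain advanced indexing.
out_index = weight[atype]

# After the commit: dedicated embedding lookup with an optimized backward.
out_embed = torch.embedding(weight, atype)

# The forward results are identical; only the backward implementation differs.
assert torch.equal(out_index, out_embed)
```

Since the forward outputs match exactly, the change is purely a performance optimization of the gradient computation, which matches the 32 ms → 0.5 ms profiling numbers reported above.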
